diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 09d525311713d..94068ec7b7357 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -93,7 +93,14 @@ import com.facebook.presto.operator.aggregation.histogram.Histogram; import com.facebook.presto.operator.aggregation.multimapagg.AlternativeMultimapAggregationFunction; import com.facebook.presto.operator.aggregation.multimapagg.MultimapAggregationFunction; +import com.facebook.presto.operator.aggregation.noisyaggregation.NoisyApproximateDistinctCountSfmAggregation; +import com.facebook.presto.operator.aggregation.noisyaggregation.NoisyApproximateDistinctCountSfmAggregationDefaultBucketsPrecision; +import com.facebook.presto.operator.aggregation.noisyaggregation.NoisyApproximateDistinctCountSfmAggregationDefaultPrecision; +import com.facebook.presto.operator.aggregation.noisyaggregation.NoisyApproximateSetSfmAggregation; +import com.facebook.presto.operator.aggregation.noisyaggregation.NoisyApproximateSetSfmAggregationDefaultBucketsPrecision; +import com.facebook.presto.operator.aggregation.noisyaggregation.NoisyApproximateSetSfmAggregationDefaultPrecision; import com.facebook.presto.operator.aggregation.noisyaggregation.NoisyCountIfGaussianAggregation; +import com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchMergeAggregation; import com.facebook.presto.operator.scalar.ArrayAllMatchFunction; import com.facebook.presto.operator.scalar.ArrayAnyMatchFunction; import com.facebook.presto.operator.scalar.ArrayCardinalityFunction; @@ -173,6 +180,7 @@ import com.facebook.presto.operator.scalar.RepeatFunction; import com.facebook.presto.operator.scalar.SequenceFunction; import com.facebook.presto.operator.scalar.SessionFunctions; +import com.facebook.presto.operator.scalar.SfmSketchFunctions; import com.facebook.presto.operator.scalar.SplitToMapFunction; import com.facebook.presto.operator.scalar.SplitToMultimapFunction; import com.facebook.presto.operator.scalar.StringFunctions; @@ -240,6 +248,7 @@ import com.facebook.presto.type.MapParametricType; import com.facebook.presto.type.QuantileDigestOperators; import com.facebook.presto.type.RealOperators; +import com.facebook.presto.type.SfmSketchOperators; import com.facebook.presto.type.SmallintOperators; import com.facebook.presto.type.TDigestOperators; import com.facebook.presto.type.TimeOperators; @@ -489,6 +498,7 @@ import static com.facebook.presto.type.MapParametricType.MAP; import static com.facebook.presto.type.Re2JRegexpType.RE2J_REGEXP; import static com.facebook.presto.type.RowParametricType.ROW; +import static com.facebook.presto.type.SfmSketchType.SFM_SKETCH; import static com.facebook.presto.type.TypeUtils.resolveTypes; import static com.facebook.presto.type.khyperloglog.KHyperLogLogType.K_HYPER_LOG_LOG; import static com.facebook.presto.type.setdigest.SetDigestType.SET_DIGEST; @@ -611,6 +621,7 @@ private void registerBuiltInTypes() addType(SET_DIGEST); addType(K_HYPER_LOG_LOG); addType(P4_HYPER_LOG_LOG); + addType(SFM_SKETCH); addType(JONI_REGEXP); addType(RE2J_REGEXP); addType(LIKE_PATTERN); @@ -654,6 +665,12 @@ private List getBuildInFunctions(FeaturesConfig featuresC .window(LeadFunction.class) .aggregate(ApproximateCountDistinctAggregation.class) .aggregate(DefaultApproximateCountDistinctAggregation.class) + .aggregate(NoisyApproximateSetSfmAggregation.class) + .aggregate(NoisyApproximateSetSfmAggregationDefaultBucketsPrecision.class) + .aggregate(NoisyApproximateSetSfmAggregationDefaultPrecision.class) + .aggregate(NoisyApproximateDistinctCountSfmAggregation.class) + .aggregate(NoisyApproximateDistinctCountSfmAggregationDefaultBucketsPrecision.class) + .aggregate(NoisyApproximateDistinctCountSfmAggregationDefaultPrecision.class) .aggregate(SumDataSizeForStats.class) .aggregate(MaxDataSizeForStats.class) .aggregate(ConvexHullAggregation.class) @@ -688,6 +705,7 @@ private List getBuildInFunctions(FeaturesConfig featuresC .aggregates(GeometricMeanAggregations.class) .aggregates(RealGeometricMeanAggregations.class) .aggregates(MergeHyperLogLogAggregation.class) + .aggregates(SfmSketchMergeAggregation.class) .aggregates(ApproximateSetAggregation.class) .functions(QDIGEST_AGG, QDIGEST_AGG_WITH_WEIGHT, QDIGEST_AGG_WITH_WEIGHT_AND_ERROR) .function(MergeQuantileDigestFunction.MERGE) @@ -741,6 +759,7 @@ private List getBuildInFunctions(FeaturesConfig featuresC .scalars(KdbTreeCasts.class) .scalar(ColorOperators.ColorDistinctFromOperator.class) .scalars(HyperLogLogFunctions.class) + .scalars(SfmSketchFunctions.class) .scalars(QuantileDigestFunctions.class) .scalars(UnknownOperators.class) .scalar(UnknownOperators.UnknownDistinctFromOperator.class) @@ -778,6 +797,7 @@ private List getBuildInFunctions(FeaturesConfig featuresC .scalar(TimestampWithTimeZoneOperators.TimestampWithTimeZoneDistinctFromOperator.class) .scalars(DateTimeOperators.class) .scalars(HyperLogLogOperators.class) + .scalars(SfmSketchOperators.class) .scalars(QuantileDigestOperators.class) .scalars(IpAddressOperators.class) .scalar(IpAddressOperators.IpAddressDistinctFromOperator.class) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateDistinctCountSfmAggregation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateDistinctCountSfmAggregation.java new file mode 100644 index 0000000000000..05fdb290c9241 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateDistinctCountSfmAggregation.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OperatorDependency; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import io.airlift.slice.Slice; + +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.common.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.addHashToSketch; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashDouble; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashLong; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashSlice; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.mergeStates; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.writeCardinality; + +@AggregationFunction(value = "noisy_approx_distinct_sfm") +public final class NoisyApproximateDistinctCountSfmAggregation +{ + private NoisyApproximateDistinctCountSfmAggregation() {} + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") long value, + @SqlType(StandardTypes.DOUBLE) double epsilon, + @SqlType(StandardTypes.BIGINT) long numberOfBuckets, + @SqlType(StandardTypes.BIGINT) long precision) + { + addHashToSketch(state, hashLong(methodHandle, value), epsilon, numberOfBuckets, precision); + } + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") double value, + @SqlType(StandardTypes.DOUBLE) double epsilon, + @SqlType(StandardTypes.BIGINT) long numberOfBuckets, + @SqlType(StandardTypes.BIGINT) long precision) + { + addHashToSketch(state, hashDouble(methodHandle, value), epsilon, numberOfBuckets, precision); + } + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") Slice value, + @SqlType(StandardTypes.DOUBLE) double epsilon, + @SqlType(StandardTypes.BIGINT) long numberOfBuckets, + @SqlType(StandardTypes.BIGINT) long precision) + { + addHashToSketch(state, hashSlice(methodHandle, value), epsilon, numberOfBuckets, precision); + } + + @CombineFunction + public static void combineState(@AggregationState SfmSketchState state, @AggregationState SfmSketchState otherState) + { + mergeStates(state, otherState); + } + + @OutputFunction(StandardTypes.BIGINT) + public static void evaluateFinal(@AggregationState SfmSketchState state, BlockBuilder out) + { + writeCardinality(state, out); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateDistinctCountSfmAggregationDefaultBucketsPrecision.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateDistinctCountSfmAggregationDefaultBucketsPrecision.java new file mode 100644 index 0000000000000..0cd683a5ec889 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateDistinctCountSfmAggregationDefaultBucketsPrecision.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OperatorDependency; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import io.airlift.slice.Slice; + +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.common.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.DEFAULT_BUCKET_COUNT; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.DEFAULT_PRECISION; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.addHashToSketch; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashDouble; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashLong; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashSlice; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.mergeStates; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.writeCardinality; + +@AggregationFunction(value = "noisy_approx_distinct_sfm") +public final class NoisyApproximateDistinctCountSfmAggregationDefaultBucketsPrecision +{ + private NoisyApproximateDistinctCountSfmAggregationDefaultBucketsPrecision() {} + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") long value, + @SqlType(StandardTypes.DOUBLE) double epsilon) + { + addHashToSketch(state, hashLong(methodHandle, value), epsilon, DEFAULT_BUCKET_COUNT, DEFAULT_PRECISION); + } + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") double value, + @SqlType(StandardTypes.DOUBLE) double epsilon) + { + addHashToSketch(state, hashDouble(methodHandle, value), epsilon, DEFAULT_BUCKET_COUNT, DEFAULT_PRECISION); + } + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") Slice value, + @SqlType(StandardTypes.DOUBLE) double epsilon) + { + addHashToSketch(state, hashSlice(methodHandle, value), epsilon, DEFAULT_BUCKET_COUNT, DEFAULT_PRECISION); + } + + @CombineFunction + public static void combineState(@AggregationState SfmSketchState state, @AggregationState SfmSketchState otherState) + { + mergeStates(state, otherState); + } + + @OutputFunction(StandardTypes.BIGINT) + public static void evaluateFinal(@AggregationState SfmSketchState state, BlockBuilder out) + { + writeCardinality(state, out); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateDistinctCountSfmAggregationDefaultPrecision.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateDistinctCountSfmAggregationDefaultPrecision.java new file mode 100644 index 0000000000000..d0042f4933974 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateDistinctCountSfmAggregationDefaultPrecision.java @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OperatorDependency; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import io.airlift.slice.Slice; + +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.common.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.DEFAULT_PRECISION; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.addHashToSketch; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashDouble; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashLong; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashSlice; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.mergeStates; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.writeCardinality; + +@AggregationFunction(value = "noisy_approx_distinct_sfm") +public final class NoisyApproximateDistinctCountSfmAggregationDefaultPrecision +{ + private NoisyApproximateDistinctCountSfmAggregationDefaultPrecision() {} + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") long value, + @SqlType(StandardTypes.DOUBLE) double epsilon, + @SqlType(StandardTypes.BIGINT) long numberOfBuckets) + { + addHashToSketch(state, hashLong(methodHandle, value), epsilon, numberOfBuckets, DEFAULT_PRECISION); + } + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") double value, + @SqlType(StandardTypes.DOUBLE) double epsilon, + @SqlType(StandardTypes.BIGINT) long numberOfBuckets) + { + addHashToSketch(state, hashDouble(methodHandle, value), epsilon, numberOfBuckets, DEFAULT_PRECISION); + } + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") Slice value, + @SqlType(StandardTypes.DOUBLE) double epsilon, + @SqlType(StandardTypes.BIGINT) long numberOfBuckets) + { + addHashToSketch(state, hashSlice(methodHandle, value), epsilon, numberOfBuckets, DEFAULT_PRECISION); + } + + @CombineFunction + public static void combineState(@AggregationState SfmSketchState state, @AggregationState SfmSketchState otherState) + { + mergeStates(state, otherState); + } + + @OutputFunction(StandardTypes.BIGINT) + public static void evaluateFinal(@AggregationState SfmSketchState state, BlockBuilder out) + { + writeCardinality(state, out); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateSetSfmAggregation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateSetSfmAggregation.java new file mode 100644 index 0000000000000..da555feaa4c5e --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateSetSfmAggregation.java @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OperatorDependency; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.type.SfmSketchType; +import io.airlift.slice.Slice; + +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.common.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.addHashToSketch; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashDouble; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashLong; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashSlice; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.mergeStates; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.writeSketch; + +@AggregationFunction(value = "noisy_approx_set_sfm") +public final class NoisyApproximateSetSfmAggregation +{ + private NoisyApproximateSetSfmAggregation() {} + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") long value, + @SqlType(StandardTypes.DOUBLE) double epsilon, + @SqlType(StandardTypes.BIGINT) long numberOfBuckets, + @SqlType(StandardTypes.BIGINT) long precision) + { + addHashToSketch(state, hashLong(methodHandle, value), epsilon, numberOfBuckets, precision); + } + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") double value, + @SqlType(StandardTypes.DOUBLE) double epsilon, + @SqlType(StandardTypes.BIGINT) long numberOfBuckets, + @SqlType(StandardTypes.BIGINT) long precision) + { + addHashToSketch(state, hashDouble(methodHandle, value), epsilon, numberOfBuckets, precision); + } + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") Slice value, + @SqlType(StandardTypes.DOUBLE) double epsilon, + @SqlType(StandardTypes.BIGINT) long numberOfBuckets, + @SqlType(StandardTypes.BIGINT) long precision) + { + addHashToSketch(state, hashSlice(methodHandle, value), epsilon, numberOfBuckets, precision); + } + + @CombineFunction + public static void combineState(@AggregationState SfmSketchState state, @AggregationState SfmSketchState otherState) + { + mergeStates(state, otherState); + } + + @OutputFunction(SfmSketchType.NAME) + public static void evaluateFinal(@AggregationState SfmSketchState state, BlockBuilder out) + { + writeSketch(state, out); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateSetSfmAggregationDefaultBucketsPrecision.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateSetSfmAggregationDefaultBucketsPrecision.java new file mode 100644 index 0000000000000..66bd69e465ca6 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateSetSfmAggregationDefaultBucketsPrecision.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OperatorDependency; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.type.SfmSketchType; +import io.airlift.slice.Slice; + +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.common.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.DEFAULT_BUCKET_COUNT; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.DEFAULT_PRECISION; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.addHashToSketch; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashDouble; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashLong; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashSlice; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.mergeStates; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.writeSketch; + +@AggregationFunction(value = "noisy_approx_set_sfm") +public final class NoisyApproximateSetSfmAggregationDefaultBucketsPrecision +{ + private NoisyApproximateSetSfmAggregationDefaultBucketsPrecision() {} + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") long value, + @SqlType(StandardTypes.DOUBLE) double epsilon) + { + addHashToSketch(state, hashLong(methodHandle, value), epsilon, DEFAULT_BUCKET_COUNT, DEFAULT_PRECISION); + } + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") double value, + @SqlType(StandardTypes.DOUBLE) double epsilon) + { + addHashToSketch(state, hashDouble(methodHandle, value), epsilon, DEFAULT_BUCKET_COUNT, DEFAULT_PRECISION); + } + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") Slice value, + @SqlType(StandardTypes.DOUBLE) double epsilon) + { + addHashToSketch(state, hashSlice(methodHandle, value), epsilon, DEFAULT_BUCKET_COUNT, DEFAULT_PRECISION); + } + + @CombineFunction + public static void combineState(@AggregationState SfmSketchState state, @AggregationState SfmSketchState otherState) + { + mergeStates(state, otherState); + } + + @OutputFunction(SfmSketchType.NAME) + public static void evaluateFinal(@AggregationState SfmSketchState state, BlockBuilder out) + { + writeSketch(state, out); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateSetSfmAggregationDefaultPrecision.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateSetSfmAggregationDefaultPrecision.java new file mode 100644 index 0000000000000..6a6d9a05eb9ac --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/NoisyApproximateSetSfmAggregationDefaultPrecision.java @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OperatorDependency; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.type.SfmSketchType; +import io.airlift.slice.Slice; + +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.common.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.DEFAULT_PRECISION; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.addHashToSketch; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashDouble; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashLong; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.hashSlice; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.mergeStates; +import static com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils.writeSketch; + +@AggregationFunction(value = "noisy_approx_set_sfm") +public final class NoisyApproximateSetSfmAggregationDefaultPrecision +{ + private NoisyApproximateSetSfmAggregationDefaultPrecision() {} + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") long value, + @SqlType(StandardTypes.DOUBLE) double epsilon, + @SqlType(StandardTypes.BIGINT) long numberOfBuckets) + { + addHashToSketch(state, hashLong(methodHandle, value), epsilon, numberOfBuckets, DEFAULT_PRECISION); + } + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") double value, + @SqlType(StandardTypes.DOUBLE) double epsilon, + @SqlType(StandardTypes.BIGINT) long numberOfBuckets) + { + addHashToSketch(state, hashDouble(methodHandle, value), epsilon, numberOfBuckets, DEFAULT_PRECISION); + } + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency(operator = XX_HASH_64, argumentTypes = {"T"}) MethodHandle methodHandle, + @AggregationState SfmSketchState state, + @SqlType("T") Slice value, + @SqlType(StandardTypes.DOUBLE) double epsilon, + @SqlType(StandardTypes.BIGINT) long numberOfBuckets) + { + addHashToSketch(state, hashSlice(methodHandle, value), epsilon, numberOfBuckets, DEFAULT_PRECISION); + } + + @CombineFunction + public static void combineState(@AggregationState SfmSketchState state, @AggregationState SfmSketchState otherState) + { + mergeStates(state, otherState); + } + + @OutputFunction(SfmSketchType.NAME) + public static void evaluateFinal(@AggregationState SfmSketchState state, BlockBuilder out) + { + writeSketch(state, out); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/SfmSketchAggregationUtils.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/SfmSketchAggregationUtils.java new file mode 100644 index 0000000000000..331d91c1b1b66 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/SfmSketchAggregationUtils.java @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.operator.aggregation.noisyaggregation.sketch.SfmSketch; +import com.facebook.presto.spi.PrestoException; +import io.airlift.slice.Slice; + +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.util.Failures.checkCondition; +import static com.facebook.presto.util.Failures.internalError; + +public final class SfmSketchAggregationUtils +{ + public static final int DEFAULT_BUCKET_COUNT = 4096; + public static final int DEFAULT_PRECISION = 24; + + private SfmSketchAggregationUtils() {} + + public static void ensureStateInitialized(SfmSketchState state, double epsilon, long numberOfBuckets, long precision) + { + if (state.getSketch() != null) { + return; // already initialized + } + + validateSketchParameters(epsilon, (int) numberOfBuckets, (int) precision); + SfmSketch sketch = SfmSketch.create((int) numberOfBuckets, (int) precision); + state.setSketch(sketch); + state.setEpsilon(epsilon); + } + + public static void validateSketchParameters(double epsilon, int numberOfBuckets, int precision) + { + checkCondition(epsilon > 0, + INVALID_FUNCTION_ARGUMENT, "epsilon must be positive"); + checkCondition(numberOfBuckets > 0 && (numberOfBuckets & (numberOfBuckets - 1)) == 0, + INVALID_FUNCTION_ARGUMENT, "numberOfBuckets must be a power of 2"); + checkCondition(precision > 0 && precision % Byte.SIZE == 0, + INVALID_FUNCTION_ARGUMENT, "precision must be a multiple of %s", Byte.SIZE); + } + + public static void addHashToSketch(SfmSketchState state, long hash, double epsilon, long numberOfBuckets, long precision) + { + ensureStateInitialized(state, epsilon, numberOfBuckets, precision); + state.getSketch().addHash(hash); + } + + public static long hashLong(MethodHandle methodHandle, long value) + { + long hash; + try { + hash = (long) methodHandle.invokeExact(value); + } + catch (Throwable t) { + throw internalError(t); + } + return hash; + } + + public static long hashDouble(MethodHandle methodHandle, double value) + { + long hash; + try { + hash = (long) methodHandle.invokeExact(value); + } + catch (Throwable t) { + throw internalError(t); + } + return hash; + } + + public static long hashSlice(MethodHandle methodHandle, Slice value) + { + long hash; + try { + hash = (long) methodHandle.invokeExact(value); + } + catch (Throwable t) { + throw internalError(t); + } + return hash; + } + + public static void mergeStates(SfmSketchState state, SfmSketchState otherState) + { + SfmSketch sketch = state.getSketch(); + SfmSketch otherSketch = otherState.getSketch(); + if (sketch == null) { + state.setSketch(otherSketch); + state.setEpsilon(otherState.getEpsilon()); + } + else { + try { + // Throws if the sketches are incompatible (e.g., different bucket counts/size) + // Catch and throw a PrestoException + state.mergeSketch(otherSketch); + } + catch (IllegalArgumentException e) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, e.getMessage(), e); + } + } + } + + public static void writeCardinality(SfmSketchState state, BlockBuilder out) + { + SfmSketch sketch = state.getSketch(); + if (sketch == null) { + // In the event we process no rows, output NULL. + // Note: Although the SfmSketch is differentially private, this particular output is not. + // We cannot output anything DP here because the state will not include epsilon if we processed no rows. + out.appendNull(); + } + else { + sketch.enablePrivacy(state.getEpsilon()); + BIGINT.writeLong(out, sketch.cardinality()); + } + } + + public static void writeSketch(SfmSketchState state, BlockBuilder out) + { + SfmSketch sketch = state.getSketch(); + if (sketch == null) { + // In the event we process no rows, output NULL. + // Note: Although the SfmSketch is differentially private, this particular output is not. + // We cannot output anything DP here because the state will not include epsilon if we processed no rows. + out.appendNull(); + } + else { + sketch.enablePrivacy(state.getEpsilon()); + VARBINARY.writeSlice(out, sketch.serialize()); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/SfmSketchMergeAggregation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/SfmSketchMergeAggregation.java new file mode 100644 index 0000000000000..7199e3aaedb9c --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/SfmSketchMergeAggregation.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.operator.aggregation.noisyaggregation.sketch.SfmSketch; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.type.SfmSketchType; +import io.airlift.slice.Slice; + +import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; + +@AggregationFunction("merge") +public final class SfmSketchMergeAggregation +{ + private SfmSketchMergeAggregation() {} + + @InputFunction + public static void input(@AggregationState SfmSketchState state, @SqlType(SfmSketchType.NAME) Slice value) + { + SfmSketch sketch = SfmSketch.deserialize(value); + SfmSketch previous = state.getSketch(); + if (previous == null) { + // if state sketch is empty, add this sketch to the state + state.setSketch(sketch); + state.setEpsilon(0); // not used + } + else { + // if state already has a sketch, merge in the current sketch + try { + // throws if the sketches are incompatible (e.g., different bucket count/size) + state.mergeSketch(sketch); + } + catch (IllegalArgumentException e) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, e.getMessage(), e); + } + } + } + + @CombineFunction + public static void combine(@AggregationState SfmSketchState state, @AggregationState SfmSketchState otherState) + { + SfmSketchAggregationUtils.mergeStates(state, otherState); + } + + @OutputFunction(SfmSketchType.NAME) + public static void output(@AggregationState SfmSketchState state, BlockBuilder out) + { + VARBINARY.writeSlice(out, state.getSketch().serialize()); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/SfmSketchState.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/SfmSketchState.java new file mode 100644 index 0000000000000..1aa65ceafc1ca --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/SfmSketchState.java @@ -0,0 +1,34 @@ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import com.facebook.presto.operator.aggregation.noisyaggregation.sketch.SfmSketch; +import com.facebook.presto.spi.function.AccumulatorState; +import com.facebook.presto.spi.function.AccumulatorStateMetadata; + +@AccumulatorStateMetadata(stateSerializerClass = SfmSketchStateSerializer.class, stateFactoryClass = SfmSketchStateFactory.class) +public interface SfmSketchState + extends AccumulatorState +{ + SfmSketch getSketch(); + + void mergeSketch(SfmSketch value); + + void setSketch(SfmSketch value); + + void setEpsilon(double epsilon); + + double getEpsilon(); +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/SfmSketchStateFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/SfmSketchStateFactory.java new file mode 100644 index 0000000000000..33df1f1a0618a --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/SfmSketchStateFactory.java @@ -0,0 +1,174 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +import com.facebook.presto.common.array.ObjectBigArray; +import com.facebook.presto.operator.aggregation.noisyaggregation.sketch.SfmSketch; +import com.facebook.presto.spi.function.AccumulatorStateFactory; +import com.facebook.presto.spi.function.GroupedAccumulatorState; +import org.openjdk.jol.info.ClassLayout; + +import static java.util.Objects.requireNonNull; + +public class SfmSketchStateFactory + implements AccumulatorStateFactory +{ + @Override + public SfmSketchState createSingleState() + { + return new SingleSfmSketchState(); + } + + @Override + public Class getSingleStateClass() + { + return SingleSfmSketchState.class; + } + + @Override + public SfmSketchState createGroupedState() + { + return new GroupedSfmSketchState(); + } + + @Override + public Class getGroupedStateClass() + { + return GroupedSfmSketchState.class; + } + + public static class GroupedSfmSketchState + implements SfmSketchState, GroupedAccumulatorState + { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedSfmSketchState.class).instanceSize(); + + private final ObjectBigArray sketches = new ObjectBigArray<>(); + private final ObjectBigArray epsilons = new ObjectBigArray<>(); + + private long retainedBytes; + private long groupId; + + @Override + public final void setGroupId(long groupId) + { + this.groupId = groupId; + } + + protected final long getGroupId() + { + return groupId; + } + + @Override + public void ensureCapacity(long size) + { + epsilons.ensureCapacity(size); + sketches.ensureCapacity(size); + } + + @Override + public SfmSketch getSketch() + { + return sketches.get(getGroupId()); + } + + @Override + public void mergeSketch(SfmSketch value) + { + requireNonNull(value, "value is null"); + retainedBytes -= getSketch().getRetainedSizeInBytes(); + getSketch().mergeWith(value); + retainedBytes += value.getRetainedSizeInBytes(); + } + + @Override + public void setSketch(SfmSketch value) + { + requireNonNull(value, "value is null"); + if (getSketch() != null) { + retainedBytes -= getSketch().getRetainedSizeInBytes(); + } + sketches.set(getGroupId(), value); + retainedBytes += value.getRetainedSizeInBytes(); + } + + @Override + public long getEstimatedSize() + { + return INSTANCE_SIZE + retainedBytes + sketches.sizeOf() + epsilons.sizeOf(); + } + + @Override + public void setEpsilon(double value) + { + epsilons.set(getGroupId(), value); + } + + @Override + public double getEpsilon() + { + return epsilons.get(getGroupId()); + } + } + + public static class SingleSfmSketchState + implements SfmSketchState + { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleSfmSketchState.class).instanceSize(); + + private SfmSketch sketch; + private double epsilon; + + @Override + public SfmSketch getSketch() + { + return sketch; + } + + @Override + public void mergeSketch(SfmSketch value) + { + requireNonNull(value, "value is null"); + sketch.mergeWith(value); + } + + @Override + public void setSketch(SfmSketch value) + { + sketch = value; + } + + @Override + public long getEstimatedSize() + { + long estimatedSize = INSTANCE_SIZE; + if (sketch != null) { + estimatedSize += sketch.getRetainedSizeInBytes(); + } + return estimatedSize; + } + + @Override + public void setEpsilon(double value) + { + epsilon = value; + } + + @Override + public double getEpsilon() + { + return epsilon; + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/SfmSketchStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/SfmSketchStateSerializer.java new file mode 100644 index 0000000000000..77be02b46e23d --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/SfmSketchStateSerializer.java @@ -0,0 +1,61 @@ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.operator.aggregation.noisyaggregation.sketch.SfmSketch; +import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import io.airlift.slice.BasicSliceInput; +import io.airlift.slice.DynamicSliceOutput; +import io.airlift.slice.SizeOf; +import io.airlift.slice.Slice; + +import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; + +public class SfmSketchStateSerializer + implements AccumulatorStateSerializer +{ + @Override + public Type getSerializedType() + { + return VARBINARY; + } + + @Override + public void serialize(SfmSketchState state, BlockBuilder out) + { + if (state.getSketch() == null) { + out.appendNull(); + } + else { + DynamicSliceOutput output = new DynamicSliceOutput(state.getSketch().estimatedSerializedSize() + SizeOf.SIZE_OF_DOUBLE); + output.appendDouble(state.getEpsilon()); + state.getSketch().serialize(output); + VARBINARY.writeSlice(out, output.slice()); + } + } + + @Override + public void deserialize(Block block, int index, SfmSketchState state) + { + Slice stateSlice = VARBINARY.getSlice(block, index); + BasicSliceInput input = stateSlice.getInput(); + state.setEpsilon(input.readDouble()); + Slice pcsaSlice = input.slice(); + state.setSketch(SfmSketch.deserialize(pcsaSlice)); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/Bitmap.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/Bitmap.java new file mode 100644 index 0000000000000..4ac12600bf453 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/Bitmap.java @@ -0,0 +1,173 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation.sketch; + +import io.airlift.slice.SizeOf; +import io.airlift.slice.SliceInput; +import org.openjdk.jol.info.ClassLayout; + +import java.util.BitSet; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +/** + * A level of abstraction over the bitmaps used in sketches such as SFM. + * Abstractly, these are essentially fixed-length arrays of booleans that support flipping and applying randomized response. + * This is implemented as a wrapper around Java's BitSet. + *

+ * Note: The byte arrays in toBytes() and fromSliceInput() are variable-length. + * Trailing zeros are implicitly truncated in these functions. + * The fixed-length nature of the bitmap comes into play in flipAll (randomized response), + * where every bit from 0 to length-1 must be flipped with a fixed probability. + */ +public class Bitmap +{ + private static final int INSTANCE_SIZE = ClassLayout.parseClass(Bitmap.class).instanceSize(); + private static final int BITSET_INSTANCE_SIZE = ClassLayout.parseClass(BitSet.class).instanceSize(); + + private final BitSet bitSet; + private final int length; + + public Bitmap(int length) + { + checkArgument(length > 0, "length must be positive"); + bitSet = new BitSet(length); + this.length = length; + } + + private Bitmap(int length, BitSet bitSet) + { + requireNonNull(bitSet, "bitSet cannot be null"); + checkArgument(length >= bitSet.length(), "bitmap size must be large enough to cover existing bits"); + this.bitSet = bitSet; + this.length = length; + } + + public static Bitmap fromBytes(int length, byte[] bytes) + { + return new Bitmap(length, BitSet.valueOf(bytes)); + } + + public static Bitmap fromSliceInput(SliceInput input, int byteCount, int length) + { + checkArgument(byteCount >= 0, "byteCount must be nonnegative"); + if (byteCount == 0) { + return new Bitmap(length); + } + + byte[] bytes = new byte[byteCount]; + input.readBytes(bytes); + return Bitmap.fromBytes(length, bytes); + } + + public byte[] toBytes() + { + return bitSet.toByteArray(); + } + + /** + * Length of toBytes() + */ + public int byteLength() + { + // per https://docs.oracle.com/javase/8/docs/api/java/util/BitSet.html#toByteArray-- + return (bitSet.length() + 7) / 8; + } + + @Override + public Bitmap clone() + { + return Bitmap.fromBytes(length, bitSet.toByteArray()); + } + + public long getRetainedSizeInBytes() + { + // Under the hood, BitSet stores a long[] array of BitSet.size() bits + return INSTANCE_SIZE + BITSET_INSTANCE_SIZE + SizeOf.sizeOfLongArray(bitSet.size() / Long.SIZE); + } + + public boolean getBit(int position) + { + return bitSet.get(position); + } + + /** + * The number of 1-bits in the bitmap + */ + public int getBitCount() + { + return bitSet.cardinality(); + } + + /** + * Randomly (and independently) flip all bits with specified probability + */ + public void flipAll(double probability, RandomizationStrategy randomizationStrategy) + { + for (int i = 0; i < length; i++) { + flipBit(i, probability, randomizationStrategy); + } + } + + /** + * Deterministically flips the bit at a given position + */ + public void flipBit(int position) + { + bitSet.flip(position); + } + + /** + * Randomly flips the bit at a given position with specified probability + */ + public void flipBit(int position, double probability, RandomizationStrategy randomizationStrategy) + { + if (randomizationStrategy.nextBoolean(probability)) { + flipBit(position); + } + } + + /** + * The nominal fixed length of the bitmap (actual stored size may vary) + */ + public int length() + { + return length; + } + + /** + * Explicitly set the value of the bit at a given position + */ + public void setBit(int position, boolean value) + { + bitSet.set(position, value); + } + + public void or(Bitmap other) + { + requireNonNull(other, "cannot combine with null Bitmap"); + checkArgument(length() == other.length(), "cannot OR two bitmaps of different size"); + + bitSet.or(other.bitSet); + } + + public void xor(Bitmap other) + { + requireNonNull(other, "cannot combine with null Bitmap"); + checkArgument(length() == other.length(), "cannot XOR two bitmaps of different size"); + + bitSet.xor(other.bitSet); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/RandomizationStrategy.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/RandomizationStrategy.java new file mode 100644 index 0000000000000..3a8cd69d43071 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/RandomizationStrategy.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation.sketch; + +public abstract class RandomizationStrategy +{ + public boolean nextBoolean(double probability) + { + return nextDouble() <= probability; + } + + abstract double nextDouble(); +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/SecureRandomizationStrategy.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/SecureRandomizationStrategy.java new file mode 100644 index 0000000000000..ccdc146ff63f4 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/SecureRandomizationStrategy.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation.sketch; + +import com.facebook.presto.util.SecureRandomGeneration; + +import java.security.SecureRandom; + +/** + * Note: Due to finite-precision implementation details, usage of floating-point functions + * as random noise, while cryptographically secure, may leak information from a privacy context. + * See "On Significance of the Least Significant Bits For Differential Privacy" by Mironov + * and use judiciously. + */ +public class SecureRandomizationStrategy + extends RandomizationStrategy +{ + private final SecureRandom random; + + public SecureRandomizationStrategy() + { + random = SecureRandomGeneration.getNonBlocking(); + } + + public double nextDouble() + { + return random.nextDouble(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/SfmSketch.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/SfmSketch.java new file mode 100644 index 0000000000000..f014b740c5885 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/SfmSketch.java @@ -0,0 +1,428 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation.sketch; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import io.airlift.slice.BasicSliceInput; +import io.airlift.slice.DynamicSliceOutput; +import io.airlift.slice.Murmur3Hash128; +import io.airlift.slice.SizeOf; +import io.airlift.slice.Slice; +import org.openjdk.jol.info.ClassLayout; + +import javax.annotation.concurrent.NotThreadSafe; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +/** + * SfmSketch is a sketch for distinct counting, very similar to HyperLogLog. + * This sketch is introduced as the Sketch-Flip-Merge (SFM) summary in the paper + * Sketch-Flip-Merge: Mergeable Sketches for Private Distinct Counting. + *

+ * The primary differences between SfmSketch and HyperLogLog are that + * (a) SfmSketch supports differential privacy, and + * (b) where HyperLogLog tracks only max observed bucket values, SfmSketch tracks all bucket values observed. + *

+ * This means that SfmSketch is a larger sketch than HyperLogLog, but offers the ability to store completely + * DP sketches with a fixed, public hash function while maintaining accurate cardinality estimates. + *

+ * SfmSketch is created in a non-private mode. Privacy must be enabled through the enablePrivacy() function. + * Once made private, the sketch becomes immutable. Privacy is quantified by the parameter epsilon. + *

+ * When epsilon > 0, the sketch is epsilon-DP, and bits are randomized to preserve privacy. + * When epsilon == NON_PRIVATE_EPSILON, the sketch is not private, and bits are set deterministically. + *

+ * The best accuracy comes with NON_PRIVATE_EPSILON. For private epsilons, larger gives more accuracy, + * while smaller gives more privacy. + */ +@NotThreadSafe +public class SfmSketch +{ + public static final double NON_PRIVATE_EPSILON = Double.POSITIVE_INFINITY; + + private static final byte FORMAT_TAG = 7; + private static final int MAX_ESTIMATION_ITERATIONS = 1000; + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SfmSketch.class).instanceSize(); + + private final int indexBitLength; + private final int precision; + + private double randomizedResponseProbability; + private final Bitmap bitmap; + + private SfmSketch(Bitmap bitmap, int indexBitLength, int precision, double randomizedResponseProbability) + { + requireNonNull(bitmap, "bitmap cannot be null"); + validatePrefixLength(indexBitLength); + validatePrecision(precision, indexBitLength); + validateRandomizedResponseProbability(randomizedResponseProbability); + + this.bitmap = bitmap; + this.indexBitLength = indexBitLength; + this.precision = precision; + this.randomizedResponseProbability = randomizedResponseProbability; + } + + /** + * Create a new SfmSketch in non-private mode. To make private, + * call enablePrivacy() after populating the sketch. + */ + public static SfmSketch create(int numberOfBuckets, int precision) + { + // Only create non-private sketches. + // Private sketches are immutable, so they're kind of useless to create. + double randomizedResponseProbability = getRandomizedResponseProbability(NON_PRIVATE_EPSILON); + int indexBitLength = indexBitLength(numberOfBuckets); + Bitmap bitmap = new Bitmap(numberOfBuckets * precision); + return new SfmSketch(bitmap, indexBitLength, precision, randomizedResponseProbability); + } + + public static SfmSketch deserialize(Slice serialized) + { + // Format: + // format | indexBitLength | precision | epsilon | bitmap + BasicSliceInput input = serialized.getInput(); + byte format = input.readByte(); + checkArgument(format == FORMAT_TAG, "Wrong format tag"); + + int indexBitLength = input.readInt(); + int precision = input.readInt(); + double randomizedResponseProbability = input.readDouble(); + int bitmapByteLength = input.readInt(); + + Bitmap bitmap = Bitmap.fromSliceInput(input, bitmapByteLength, numberOfBuckets(indexBitLength) * precision); + return new SfmSketch(bitmap, indexBitLength, precision, randomizedResponseProbability); + } + + public void add(long value) + { + addHash(Murmur3Hash128.hash64(value)); + } + + public void add(Slice value) + { + addHash(Murmur3Hash128.hash64(value)); + } + + public void addHash(long hash) + { + int index = computeIndex(hash, indexBitLength); + // cap zeros at precision - 1 + // essentially, we're looking at a (precision - 1)-bit hash + int zeros = Math.min(precision - 1, numberOfTrailingZeros(hash, indexBitLength)); + flipBitOn(index, zeros); + } + + /** + * Estimates cardinality via maximum psuedolikelihood (Newton's method) + */ + public long cardinality() + { + // The initial guess of 1 may seem awful, but this converges quickly, and starting small returns better results for small cardinalities. + // This generally takes <= 40 iterations, even for cardinalities as large as 10^33. + double guess = 1; + double changeInGuess = Double.POSITIVE_INFINITY; + int iterations = 0; + while (Math.abs(changeInGuess) > 0.1 && iterations < MAX_ESTIMATION_ITERATIONS) { + changeInGuess = -logLikelihoodFirstDerivative(guess) / logLikelihoodSecondDerivative(guess); + guess += changeInGuess; + iterations += 1; + } + return Math.max(0, Math.round(guess)); + } + + public static int computeIndex(long hash, int indexBitLength) + { + return (int) (hash >>> (Long.SIZE - indexBitLength)); + } + + /** + * Enable privacy on a non-privacy-enabled sketch + *

+ * Per Lemma 4.7, arXiv:2302.02056, + * flipping every bit with probability 1/(e^epsilon + 1) achieves differential privacy. + */ + public void enablePrivacy(double epsilon) + { + enablePrivacy(epsilon, getDefaultRandomizationStrategy()); + } + + public void enablePrivacy(double epsilon, RandomizationStrategy randomizationStrategy) + { + requireNonNull(randomizationStrategy, "randomizationStrategy cannot be null"); + checkArgument(!isPrivacyEnabled(), "sketch is already privacy-enabled"); + validateEpsilon(epsilon); + + randomizedResponseProbability = getRandomizedResponseProbability(epsilon); + + // Flip every bit with fixed probability + bitmap.flipAll(randomizedResponseProbability, randomizationStrategy); + } + + public int estimatedSerializedSize() + { + return SizeOf.SIZE_OF_BYTE + // type + version + SizeOf.SIZE_OF_INT + // indexBitLength + SizeOf.SIZE_OF_INT + // precision + SizeOf.SIZE_OF_DOUBLE + // randomized response probability + SizeOf.SIZE_OF_INT + // bitmap byte length + (bitmap.byteLength() * SizeOf.SIZE_OF_BYTE); // bitmap + } + + private void flipBitOn(int bucket, int level) + { + checkArgument(!isPrivacyEnabled(), "privacy-enabled SfmSketch is immutable"); + + int i = getBitLocation(bucket, level); + bitmap.setBit(i, true); + } + + @VisibleForTesting + int getBitLocation(int bucket, int level) + { + return level * numberOfBuckets(indexBitLength) + bucket; + } + + public Bitmap getBitmap() + { + return bitmap; + } + + private static RandomizationStrategy getDefaultRandomizationStrategy() + { + return new SecureRandomizationStrategy(); + } + + @VisibleForTesting + double getOnProbability() + { + // probability of a 1-bit remaining a 1-bit under randomized response + return 1 - randomizedResponseProbability; + } + + static double getRandomizedResponseProbability(double epsilon) + { + // If non-private, we don't use randomized response. + // Otherwise, flip bits with probability 1/(exp(epsilon) + 1). + if (epsilon == NON_PRIVATE_EPSILON) { + return 0; + } + return 1.0 / (Math.exp(epsilon) + 1); + } + + @VisibleForTesting + double getRandomizedResponseProbability() + { + // probability of a 0-bit flipping to a 1-bit under randomized response + return randomizedResponseProbability; + } + + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + bitmap.getRetainedSizeInBytes(); + } + + public static int indexBitLength(int numberOfBuckets) + { + Preconditions.checkArgument(isPowerOf2(numberOfBuckets), "numberOfBuckets must be a power of 2, actual: %s", numberOfBuckets); + // 2**N has N trailing zeros, and we've asserted numberOfBuckets == 2**N + return Integer.numberOfTrailingZeros(numberOfBuckets); + } + + public static boolean isPowerOf2(long value) + { + Preconditions.checkArgument(value > 0, "value must be positive"); + return (value & (value - 1)) == 0; + } + + public boolean isPrivacyEnabled() + { + return getRandomizedResponseProbability() > 0; + } + + private double logLikelihoodFirstDerivative(double n) + { + // Technically, this is the first derivative of the log of a psuedolikelihood. + double result = 0; + for (int level = 0; level < precision; level++) { + double termOn = logLikelihoodTermFirstDerivative(level, true, n); + double termOff = logLikelihoodTermFirstDerivative(level, false, n); + for (int bucket = 0; bucket < numberOfBuckets(indexBitLength); bucket++) { + result += bitmap.getBit(getBitLocation(bucket, level)) ? termOn : termOff; + } + } + return result; + } + + private double logLikelihoodTermFirstDerivative(int level, boolean on, double n) + { + double p = observationProbability(level); + int sign = on ? -1 : 1; + double c1 = on ? getOnProbability() : 1 - getOnProbability(); + double c2 = getOnProbability() - getRandomizedResponseProbability(); + return Math.log1p(-p) * (1 - c1 / (c1 + sign * c2 * Math.pow(1 - p, n))); + } + + private double logLikelihoodSecondDerivative(double n) + { + // Technically, this is the second derivative of the log of a psuedolikelihood. + double result = 0; + for (int level = 0; level < precision; level++) { + double termOn = logLikelihoodTermSecondDerivative(level, true, n); + double termOff = logLikelihoodTermSecondDerivative(level, false, n); + for (int bucket = 0; bucket < numberOfBuckets(indexBitLength); bucket++) { + result += bitmap.getBit(getBitLocation(bucket, level)) ? termOn : termOff; + } + } + return result; + } + + private double logLikelihoodTermSecondDerivative(int level, boolean on, double n) + { + double p = observationProbability(level); + int sign = on ? -1 : 1; + double c1 = on ? getOnProbability() : 1 - getOnProbability(); + double c2 = getOnProbability() - getRandomizedResponseProbability(); + return sign * c1 * c2 * Math.pow(Math.log1p(-p), 2) * Math.pow(1 - p, n) * Math.pow(c1 + sign * c2 * Math.pow(1 - p, n), -2); + } + + /** + * Merging two sketches with randomizedResponseProbability values p1 and p2 is equivalent to + * having created two non-private sketches, merged them, then enabled privacy with a + * randomizedResponseProbability value of: + *

+ * (p1 + p2 - 3 * p1 * p2) / (1 - 2 * p1 * p2) + *

+ * This can be derived from the fact that two private sketches created with epsilon1 and epsilon2 + * merge to be equivalent to a single sketch created with epsilon: + *

+ * -log(exp(-epsilon1) + exp(-epsilon2) - exp(-(epsilon1 + epsilon2)) + *

+ * For details, see Theorem 4.8, arXiv:2302.02056. + * For verification, see the unit tests. + */ + @VisibleForTesting + static double mergeRandomizedResponseProbabilities(double p1, double p2) + { + return (p1 + p2 - 3 * p1 * p2) / (1 - 2 * p1 * p2); + } + + /** + * Performs a merge of the other sketch into the current sketch. This is performed + * as a randomized merge as described in Theorem 4.8, + * arXiv:2302.02056. + *

+ * The formula used in this function is a simplification of the form presented in the original paper. + * See also Section 3, arXiv:2306.09394. + */ + public void mergeWith(SfmSketch other) + { + mergeWith(other, getDefaultRandomizationStrategy()); + } + + public void mergeWith(SfmSketch other, RandomizationStrategy randomizationStrategy) + { + requireNonNull(randomizationStrategy, "randomizationStrategy cannot be null"); + + // Strictly speaking, we may be able to provide more general merging than suggested here. + // It's not clear how useful this would be in practice. + checkArgument(precision == other.precision, "cannot merge two SFM sketches with different precision: %s vs. %s", precision, other.precision); + checkArgument(indexBitLength == other.indexBitLength, "cannot merge two SFM sketches with different indexBitLength: %s vs. %s", + indexBitLength, other.indexBitLength); + + if (!isPrivacyEnabled() && !other.isPrivacyEnabled()) { + // if neither sketch is private, we just take the OR of the sketches + bitmap.or(other.getBitmap()); + } + else { + // if either sketch is private, we combine using a randomized merge + // (the non-private case above is a special case of this more complicated math) + double p1 = randomizedResponseProbability; + double p2 = other.randomizedResponseProbability; + double p = mergeRandomizedResponseProbabilities(p1, p2); + double normalizer = (1 - 2 * p) / ((1 - 2 * p1) * (1 - 2 * p2)); + + for (int i = 0; i < bitmap.length(); i++) { + double bit1 = bitmap.getBit(i) ? 1 : 0; + double bit2 = other.bitmap.getBit(i) ? 1 : 0; + double x = 1 - 2 * p - normalizer * (1 - p1 - bit1) * (1 - p2 - bit2); + double probability = p + normalizer * x; + probability = Math.min(1.0, Math.max(0.0, probability)); + bitmap.setBit(i, randomizationStrategy.nextBoolean(probability)); + } + } + + randomizedResponseProbability = mergeRandomizedResponseProbabilities(randomizedResponseProbability, other.randomizedResponseProbability); + } + + public static int numberOfBuckets(int indexBitLength) + { + return 1 << indexBitLength; + } + + public static int numberOfTrailingZeros(long hash, int indexBitLength) + { + long value = hash | (1L << (Long.SIZE - indexBitLength)); // place a 1 in the final position of the prefix to avoid flowing into prefix when the hash happens to be 0 + return Long.numberOfTrailingZeros(value); + } + + private double observationProbability(int level) + { + // probability of observing a run of zeros of length level in any single bucket + // note: this is NOT (in general) the probability of having a 1 in the corresponding location in the sketch + // (it is if bits are set deterministically, as when epsilon < 0) + return Math.pow(2.0, -(level + 1)) / numberOfBuckets(indexBitLength); + } + + public Slice serialize() + { + DynamicSliceOutput sliceOutput = new DynamicSliceOutput(estimatedSerializedSize()); + serialize(sliceOutput); + return sliceOutput.slice(); + } + + public void serialize(DynamicSliceOutput sliceOutput) + { + byte[] bitmapBytes = bitmap.toBytes(); + sliceOutput.appendByte(FORMAT_TAG) + .appendInt(indexBitLength) + .appendInt(precision) + .appendDouble(randomizedResponseProbability) + .appendInt(bitmapBytes.length) + .appendBytes(bitmapBytes); + } + + private static void validateEpsilon(double epsilon) + { + checkArgument(epsilon > 0, "epsilon must be greater than zero or equal to NON_PRIVATE_EPSILON"); + } + + private static void validatePrecision(int precision, int indexBitLength) + { + checkArgument(precision > 0, "precision must be positive", Byte.SIZE); + checkArgument(precision + indexBitLength <= Long.SIZE, "precision + indexBitLength cannot exceed %s", Long.SIZE); + } + + private static void validatePrefixLength(int indexBitLength) + { + checkArgument(indexBitLength >= 1 && indexBitLength <= 32, "indexBitLength is out of range"); + } + + private static void validateRandomizedResponseProbability(double p) + { + checkArgument(p >= 0 && p <= 0.5, "randomizedResponseProbability should be in the interval [0, 0.5]"); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java index e05e0112f216e..6dde8ca21e66c 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java @@ -28,6 +28,7 @@ import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.type.Constraint; import com.facebook.presto.type.LiteralParameter; +import com.facebook.presto.util.SecureRandomGeneration; import com.google.common.primitives.Doubles; import io.airlift.slice.Slice; import org.apache.commons.math3.distribution.BetaDistribution; @@ -43,7 +44,6 @@ import java.math.BigDecimal; import java.math.BigInteger; -import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.util.concurrent.ThreadLocalRandom; @@ -61,7 +61,6 @@ import static com.facebook.presto.common.type.UnscaledDecimal128Arithmetic.unscaledDecimalToUnscaledLong; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static com.facebook.presto.spi.function.FunctionKind.SCALAR; import static com.facebook.presto.type.DecimalOperators.modulusScalarFunction; @@ -99,13 +98,6 @@ public final class MathFunctions } } - private static final String SECURE_RANDOM_ALGORITHM; - - static { - String os = System.getProperty("os.name"); - SECURE_RANDOM_ALGORITHM = os.startsWith("Windows") ? "SHA1PRNG" : "NativePRNGNonBlocking"; - } - private MathFunctions() {} @Description("absolute value") @@ -703,13 +695,8 @@ public static long random(@SqlType(StandardTypes.BIGINT) long value) @SqlType(StandardTypes.DOUBLE) public static double secure_random() { - try { - SecureRandom random = SecureRandom.getInstance(SECURE_RANDOM_ALGORITHM); - return random.nextDouble(); - } - catch (NoSuchAlgorithmException e) { - throw new PrestoException(NOT_SUPPORTED, SECURE_RANDOM_ALGORITHM + " is not supported in your OS", e); - } + SecureRandom random = SecureRandomGeneration.getNonBlocking(); + return random.nextDouble(); } @Description("a cryptographically secure random number between lower and upper (exclusive)") @@ -718,15 +705,10 @@ public static double secure_random() public static double secure_random(@SqlType(StandardTypes.DOUBLE) double lower, @SqlType(StandardTypes.DOUBLE) double upper) { checkCondition(lower < upper, INVALID_FUNCTION_ARGUMENT, "upper bound must be greater than lower bound"); - try { - SecureRandom random = SecureRandom.getInstance(SECURE_RANDOM_ALGORITHM); - return random.doubles(lower, upper) - .findFirst() - .getAsDouble(); - } - catch (NoSuchAlgorithmException e) { - throw new PrestoException(NOT_SUPPORTED, SECURE_RANDOM_ALGORITHM + " is not supported in your OS", e); - } + SecureRandom random = SecureRandomGeneration.getNonBlocking(); + return random.doubles(lower, upper) + .findFirst() + .getAsDouble(); } @Description("a cryptographically secure random number between lower and upper (exclusive)") @@ -735,15 +717,10 @@ public static double secure_random(@SqlType(StandardTypes.DOUBLE) double lower, public static long secureRandomTinyint(@SqlType(StandardTypes.TINYINT) long lower, @SqlType(StandardTypes.TINYINT) long upper) { checkCondition(lower < upper, INVALID_FUNCTION_ARGUMENT, "upper bound must be greater than lower bound"); - try { - SecureRandom random = SecureRandom.getInstance(SECURE_RANDOM_ALGORITHM); - return random.ints((int) lower, (int) upper) - .findFirst() - .getAsInt(); - } - catch (NoSuchAlgorithmException e) { - throw new PrestoException(NOT_SUPPORTED, SECURE_RANDOM_ALGORITHM + " is not supported in your OS", e); - } + SecureRandom random = SecureRandomGeneration.getNonBlocking(); + return random.ints((int) lower, (int) upper) + .findFirst() + .getAsInt(); } @Description("a cryptographically secure random number between lower and upper (exclusive)") @@ -752,15 +729,10 @@ public static long secureRandomTinyint(@SqlType(StandardTypes.TINYINT) long lowe public static long secureRandomSmallint(@SqlType(StandardTypes.SMALLINT) long lower, @SqlType(StandardTypes.SMALLINT) long upper) { checkCondition(lower < upper, INVALID_FUNCTION_ARGUMENT, "upper bound must be greater than lower bound"); - try { - SecureRandom random = SecureRandom.getInstance(SECURE_RANDOM_ALGORITHM); - return random.ints((int) lower, (int) upper) - .findFirst() - .getAsInt(); - } - catch (NoSuchAlgorithmException e) { - throw new PrestoException(NOT_SUPPORTED, SECURE_RANDOM_ALGORITHM + " is not supported in your OS", e); - } + SecureRandom random = SecureRandomGeneration.getNonBlocking(); + return random.ints((int) lower, (int) upper) + .findFirst() + .getAsInt(); } @Description("a cryptographically secure random number between lower and upper (exclusive)") @@ -769,15 +741,10 @@ public static long secureRandomSmallint(@SqlType(StandardTypes.SMALLINT) long lo public static long secureRandomInteger(@SqlType(StandardTypes.INTEGER) long lower, @SqlType(StandardTypes.INTEGER) long upper) { checkCondition(lower < upper, INVALID_FUNCTION_ARGUMENT, "upper bound must be greater than lower bound"); - try { - SecureRandom random = SecureRandom.getInstance(SECURE_RANDOM_ALGORITHM); - return random.ints((int) lower, (int) upper) - .findFirst() - .getAsInt(); - } - catch (NoSuchAlgorithmException e) { - throw new PrestoException(NOT_SUPPORTED, SECURE_RANDOM_ALGORITHM + " is not supported in your OS", e); - } + SecureRandom random = SecureRandomGeneration.getNonBlocking(); + return random.ints((int) lower, (int) upper) + .findFirst() + .getAsInt(); } @Description("a cryptographically secure random number between lower and upper (exclusive)") @@ -786,15 +753,10 @@ public static long secureRandomInteger(@SqlType(StandardTypes.INTEGER) long lowe public static long secureRandomBigint(@SqlType(StandardTypes.BIGINT) long lower, @SqlType(StandardTypes.BIGINT) long upper) { checkCondition(lower < upper, INVALID_FUNCTION_ARGUMENT, "upper bound must be greater than lower bound"); - try { - SecureRandom random = SecureRandom.getInstance(SECURE_RANDOM_ALGORITHM); - return random.longs(lower, upper) - .findFirst() - .getAsLong(); - } - catch (NoSuchAlgorithmException e) { - throw new PrestoException(NOT_SUPPORTED, SECURE_RANDOM_ALGORITHM + " is not supported in your OS", e); - } + SecureRandom random = SecureRandomGeneration.getNonBlocking(); + return random.longs(lower, upper) + .findFirst() + .getAsLong(); } @Description("Inverse of normal cdf given a mean, std, and probability") diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/SfmSketchFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/SfmSketchFunctions.java new file mode 100644 index 0000000000000..d06eddb3f449c --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/SfmSketchFunctions.java @@ -0,0 +1,99 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchAggregationUtils; +import com.facebook.presto.operator.aggregation.noisyaggregation.sketch.SfmSketch; +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlNullable; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.type.SfmSketchType; +import io.airlift.slice.Slice; + +public final class SfmSketchFunctions +{ + private SfmSketchFunctions() {} + + @ScalarFunction + @Description("estimated cardinality of an SfmSketch object") + @SqlType(StandardTypes.BIGINT) + public static long cardinality(@SqlType(SfmSketchType.NAME) Slice serializedSketch) + { + return SfmSketch.deserialize(serializedSketch).cardinality(); + } + + @ScalarFunction(value = "merge_sfm", deterministic = false) + @Description("merge the contents of an array of SfmSketch objects") + @SqlType(SfmSketchType.NAME) + @SqlNullable + public static Slice scalarMerge(@SqlType("array(SfmSketch)") Block block) + { + if (block.getPositionCount() == 0) { + return null; + } + + SfmSketch merged = null; + for (int i = 0; i < block.getPositionCount(); i++) { + if (block.isNull(i)) { + continue; + } + SfmSketch sketch = SfmSketch.deserialize(block.getSlice(i, 0, block.getSliceLength(i))); + if (merged == null) { + merged = sketch; + } + else { + merged.mergeWith(sketch); + } + } + + if (merged == null) { + return null; + } + + return merged.serialize(); + } + + @ScalarFunction(value = "noisy_empty_approx_set_sfm", deterministic = false) + @Description("an SfmSketch object representing an empty set") + @SqlType(SfmSketchType.NAME) + public static Slice emptyApproxSet(@SqlType(StandardTypes.DOUBLE) double epsilon, + @SqlType(StandardTypes.BIGINT) long numberOfBuckets, + @SqlType(StandardTypes.BIGINT) long precision) + { + SfmSketchAggregationUtils.validateSketchParameters(epsilon, (int) numberOfBuckets, (int) precision); + SfmSketch sketch = SfmSketch.create((int) numberOfBuckets, (int) precision); + sketch.enablePrivacy(epsilon); + return sketch.serialize(); + } + + @ScalarFunction(value = "noisy_empty_approx_set_sfm", deterministic = false) + @Description("an SfmSketch object representing an empty set") + @SqlType(SfmSketchType.NAME) + public static Slice emptyApproxSet(@SqlType(StandardTypes.DOUBLE) double epsilon, + @SqlType(StandardTypes.BIGINT) long numberOfBuckets) + { + return emptyApproxSet(epsilon, numberOfBuckets, SfmSketchAggregationUtils.DEFAULT_PRECISION); + } + + @ScalarFunction(value = "noisy_empty_approx_set_sfm", deterministic = false) + @Description("an SfmSketch object representing an empty set") + @SqlType(SfmSketchType.NAME) + public static Slice emptyApproxSet(@SqlType(StandardTypes.DOUBLE) double epsilon) + { + return emptyApproxSet(epsilon, SfmSketchAggregationUtils.DEFAULT_BUCKET_COUNT, SfmSketchAggregationUtils.DEFAULT_PRECISION); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/type/SfmSketchOperators.java b/presto-main/src/main/java/com/facebook/presto/type/SfmSketchOperators.java new file mode 100644 index 0000000000000..895b0e1fa0a17 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/type/SfmSketchOperators.java @@ -0,0 +1,43 @@ +package com.facebook.presto.type; + +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.spi.function.ScalarOperator; +import com.facebook.presto.spi.function.SqlType; +import io.airlift.slice.Slice; + +import static com.facebook.presto.common.function.OperatorType.CAST; + +public final class SfmSketchOperators +{ + private SfmSketchOperators() + { + } + + @ScalarOperator(CAST) + @SqlType(StandardTypes.VARBINARY) + public static Slice castToVarbinary(@SqlType(SfmSketchType.NAME) Slice slice) + { + return slice; + } + + @ScalarOperator(CAST) + @SqlType(SfmSketchType.NAME) + public static Slice castFromVarbinary(@SqlType(StandardTypes.VARBINARY) Slice slice) + { + return slice; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/type/SfmSketchType.java b/presto-main/src/main/java/com/facebook/presto/type/SfmSketchType.java new file mode 100644 index 0000000000000..512f91c67d274 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/type/SfmSketchType.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.type; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.function.SqlFunctionProperties; +import com.facebook.presto.common.type.AbstractVariableWidthType; +import com.facebook.presto.common.type.SqlVarbinary; +import com.fasterxml.jackson.annotation.JsonCreator; +import io.airlift.slice.Slice; + +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; + +public class SfmSketchType + extends AbstractVariableWidthType +{ + public static final SfmSketchType SFM_SKETCH = new SfmSketchType(); + public static final String NAME = "SfmSketch"; + + @JsonCreator + public SfmSketchType() + { + super(parseTypeSignature(NAME), Slice.class); + } + + @Override + public void appendTo(Block block, int position, BlockBuilder blockBuilder) + { + if (block.isNull(position)) { + blockBuilder.appendNull(); + } + else { + block.writeBytesTo(position, 0, block.getSliceLength(position), blockBuilder); + blockBuilder.closeEntry(); + } + } + + @Override + public Slice getSlice(Block block, int position) + { + return block.getSlice(position, 0, block.getSliceLength(position)); + } + + @Override + public void writeSlice(BlockBuilder blockBuilder, Slice value) + { + writeSlice(blockBuilder, value, 0, value.length()); + } + + @Override + public void writeSlice(BlockBuilder blockBuilder, Slice value, int offset, int length) + { + blockBuilder.writeBytes(value, offset, length).closeEntry(); + } + + @Override + public Object getObjectValue(SqlFunctionProperties properties, Block block, int position) + { + if (block.isNull(position)) { + return null; + } + + return new SqlVarbinary(block.getSlice(position, 0, block.getSliceLength(position)).getBytes()); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/util/SecureRandomGeneration.java b/presto-main/src/main/java/com/facebook/presto/util/SecureRandomGeneration.java new file mode 100644 index 0000000000000..e13150ed27392 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/util/SecureRandomGeneration.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.util; + +import com.facebook.presto.spi.PrestoException; + +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; + +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; + +public final class SecureRandomGeneration +{ + private SecureRandomGeneration() {} + + public static String getNonBlockingAlgorithmName() + { + String os = System.getProperty("os.name"); + return os.startsWith("Windows") ? "SHA1PRNG" : "NativePRNGNonBlocking"; + } + + /** + * Return a non-blocking instance of SecureRandom, or throw PrestoException if not supported + *

+ * With the exception of Windows machines, this uses the NativePRNGNonBlocking algorithm. + * On Windows, this uses SHA1PRNG. + */ + public static SecureRandom getNonBlocking() + { + String algorithm = getNonBlockingAlgorithmName(); + try { + return SecureRandom.getInstance(algorithm); + } + catch (NoSuchAlgorithmException e) { + throw new PrestoException(NOT_SUPPORTED, algorithm + " is not supported in your OS", e); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/AbstractTestNoisySfmAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/AbstractTestNoisySfmAggregation.java new file mode 100644 index 0000000000000..5de7af6c1b7a5 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/AbstractTestNoisySfmAggregation.java @@ -0,0 +1,179 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.SqlVarbinary; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.operator.aggregation.noisyaggregation.sketch.SfmSketch; +import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation; +import com.facebook.presto.type.DoubleOperators; +import com.facebook.presto.type.IntegerOperators; +import com.facebook.presto.type.VarcharOperators; +import io.airlift.slice.Slices; +import org.apache.commons.math3.util.Precision; + +import java.util.function.BiFunction; + +import static com.facebook.presto.block.BlockAssertions.createDoubleRepeatBlock; +import static com.facebook.presto.block.BlockAssertions.createLongRepeatBlock; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; + +/** + * Parent class for testing noisy_approx_set_sfm and noisy_approx_distinct_sfm. + * The tests will essentially be the same, but since noisy_approx_set_sfm returns the sketch object itself, + * we will map this to a cardinality, giving us something equivalent to noisy_approx_distinct_sfm. + */ +abstract class AbstractTestNoisySfmAggregation +{ + private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager(); + + protected abstract String getFunctionName(); + + protected abstract long getCardinalityFromResult(Object result); + + protected SfmSketch getSketchFromResult(Object result) + { + return SfmSketch.deserialize(Slices.wrappedBuffer(((SqlVarbinary) result).getBytes())); + } + + protected static JavaAggregationFunctionImplementation getAggregator(String functionName, Type... type) + { + return FUNCTION_AND_TYPE_MANAGER.getJavaAggregateFunctionImplementation( + FUNCTION_AND_TYPE_MANAGER.lookupFunction(functionName, fromTypes(type))); + } + + protected boolean equalCardinalityWithAbsoluteError(Object actual, Object expected, double delta) + { + long actualCardinality = getCardinalityFromResult(actual); + long expectedCardinality = getCardinalityFromResult(expected); + return Precision.equals(actualCardinality, expectedCardinality, delta); + } + + /** + * Run assertion on function with signature F(value, epsilon, numberOfBuckets, precision) + */ + protected void assertFunction(Block valuesBlock, Type valueType, double epsilon, int numberOfBuckets, int precision, BiFunction assertion, Object expected) + { + assertAggregation( + getAggregator(getFunctionName(), valueType, DOUBLE, BIGINT, BIGINT), + assertion, + null, + new Page( + valuesBlock, + createDoubleRepeatBlock(epsilon, valuesBlock.getPositionCount()), + createLongRepeatBlock(numberOfBuckets, valuesBlock.getPositionCount()), + createLongRepeatBlock(precision, valuesBlock.getPositionCount())), + expected); + } + + /** + * Run assertion on function with signature F(value, epsilon, numberOfBuckets) + */ + protected void assertFunction(Block valuesBlock, Type valueType, double epsilon, int numberOfBuckets, BiFunction assertion, Object expected) + { + assertAggregation( + getAggregator(getFunctionName(), valueType, DOUBLE, BIGINT), + assertion, + null, + new Page( + valuesBlock, + createDoubleRepeatBlock(epsilon, valuesBlock.getPositionCount()), + createLongRepeatBlock(numberOfBuckets, valuesBlock.getPositionCount())), + expected); + } + + /** + * Run assertion on function with signature F(value, epsilon) + */ + protected void assertFunction(Block valuesBlock, Type valueType, double epsilon, BiFunction assertion, Object expected) + { + assertAggregation( + getAggregator(getFunctionName(), valueType, DOUBLE), + assertion, + null, + new Page( + valuesBlock, + createDoubleRepeatBlock(epsilon, valuesBlock.getPositionCount())), + expected); + } + + /** + * Assert (approximate) cardinality match on function with signature F(value, epsilon, numberOfBuckets, precision) + */ + protected void assertCardinality(Block valuesBlock, Type valueType, double epsilon, int numberOfBuckets, int precision, Object expected, long delta) + { + assertFunction(valuesBlock, valueType, epsilon, numberOfBuckets, precision, + (actualValue, expectedValue) -> equalCardinalityWithAbsoluteError(actualValue, expectedValue, delta), + expected); + } + + /** + * Assert (approximate) cardinality match on function with signature F(value, epsilon, numberOfBuckets) + */ + protected void assertCardinality(Block valuesBlock, Type valueType, double epsilon, int numberOfBuckets, Object expected, long delta) + { + assertFunction(valuesBlock, valueType, epsilon, numberOfBuckets, + (actualValue, expectedValue) -> equalCardinalityWithAbsoluteError(actualValue, expectedValue, delta), + expected); + } + + /** + * Assert (approximate) cardinality match on function with signature F(value, epsilon) + */ + protected void assertCardinality(Block valuesBlock, Type valueType, double epsilon, Object expected, long delta) + { + assertFunction(valuesBlock, valueType, epsilon, + (actualValue, expectedValue) -> equalCardinalityWithAbsoluteError(actualValue, expectedValue, delta), + expected); + } + + protected SfmSketch createLongSketch(int numberOfBuckets, int precision, int start, int end) + { + SfmSketch sketch = SfmSketch.create(numberOfBuckets, precision); + for (int i = start; i < end; i++) { + sketch.addHash(IntegerOperators.xxHash64(i)); + } + return sketch; + } + + protected SfmSketch createDoubleSketch(int numberOfBuckets, int precision, int start, int end) + { + SfmSketch sketch = SfmSketch.create(numberOfBuckets, precision); + for (int i = start; i < end; i++) { + sketch.addHash(DoubleOperators.xxHash64(i)); + } + return sketch; + } + + protected SfmSketch createStringSketch(int numberOfBuckets, int precision, int start, int end) + { + SfmSketch sketch = SfmSketch.create(numberOfBuckets, precision); + for (int i = start; i < end; i++) { + sketch.addHash(VarcharOperators.xxHash64(Slices.utf8Slice(Long.toString(i)))); + } + return sketch; + } + + protected SqlVarbinary toSqlVarbinary(SfmSketch sketch) + { + return new SqlVarbinary(sketch.serialize().getBytes()); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestMergeSfmSketchAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestMergeSfmSketchAggregation.java new file mode 100644 index 0000000000000..abade9ecdd17f --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestMergeSfmSketchAggregation.java @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.operator.aggregation.noisyaggregation.sketch.SfmSketch; +import org.testng.annotations.Test; + +import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; +import static com.facebook.presto.type.SfmSketchType.SFM_SKETCH; + +public class TestMergeSfmSketchAggregation + extends AbstractTestNoisySfmAggregation +{ + protected String getFunctionName() + { + return "merge"; + } + + protected long getCardinalityFromResult(Object result) + { + return getSketchFromResult(result).cardinality(); + } + + private Block[] buildBlocks(SfmSketch... sketches) + { + BlockBuilder builder = SFM_SKETCH.createBlockBuilder(null, 3); + for (SfmSketch sketch : sketches) { + SFM_SKETCH.writeSlice(builder, sketch.serialize()); + } + return new Block[] {builder.build()}; + } + + private void assertMergedCardinality(SfmSketch[] indivdualSketches, SfmSketch mergedSketch, double delta) + { + assertAggregation( + getAggregator(getFunctionName(), SFM_SKETCH), + (actualValue, expectedValue) -> equalCardinalityWithAbsoluteError(actualValue, expectedValue, delta), + null, + new Page(buildBlocks(indivdualSketches)), + toSqlVarbinary(mergedSketch)); + } + + @Test + public void testMergeOneNonPrivate() + { + SfmSketch sketch = createLongSketch(4096, 24, 1, 100_000); + + // deterministic test/no noise (no delta needed) + assertMergedCardinality(new SfmSketch[] {sketch}, sketch, 0); + } + + @Test + public void testMergeOnePrivate() + { + SfmSketch sketch = createLongSketch(4096, 24, 1, 100_000); + sketch.enablePrivacy(4); + + // although there is random noise in the sketch, the merge is a no-op, so no delta is needed + assertMergedCardinality(new SfmSketch[] {sketch}, sketch, 0); + } + + @Test + public void testMergeManyNonPrivate() + { + SfmSketch sketch1 = createLongSketch(4096, 24, 1, 100_000); + SfmSketch sketch2 = createLongSketch(4096, 24, 50_000, 200_000); + SfmSketch sketch3 = createLongSketch(4096, 24, 190_000, 210_000); + SfmSketch mergedSketch = createLongSketch(4096, 24, 1, 210_000); + + // deterministic test/no noise (no delta needed) + assertMergedCardinality(new SfmSketch[] {sketch1, sketch2, sketch3}, mergedSketch, 0); + } + + @Test + public void testMergeManyPrivate() + { + SfmSketch sketch1 = createLongSketch(4096, 24, 1, 100_000); + SfmSketch sketch2 = createLongSketch(4096, 24, 50_000, 200_000); + SfmSketch sketch3 = createLongSketch(4096, 24, 190_000, 210_000); + SfmSketch mergedSketch = createLongSketch(4096, 24, 1, 210_000); + sketch1.enablePrivacy(10); + sketch2.enablePrivacy(11); + sketch3.enablePrivacy(12); + + // there is randomness in this merge, so the cardinality should only match approximately + assertMergedCardinality(new SfmSketch[] {sketch1, sketch2, sketch3}, mergedSketch, 50_000); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisyApproximateCountDistinctSfmAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisyApproximateCountDistinctSfmAggregation.java new file mode 100644 index 0000000000000..bff85dab9e45c --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisyApproximateCountDistinctSfmAggregation.java @@ -0,0 +1,99 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.operator.aggregation.noisyaggregation.sketch.SfmSketch; +import org.testng.annotations.Test; + +import static com.facebook.presto.block.BlockAssertions.createDoubleSequenceBlock; +import static com.facebook.presto.block.BlockAssertions.createLongSequenceBlock; +import static com.facebook.presto.block.BlockAssertions.createStringSequenceBlock; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; + +public class TestNoisyApproximateCountDistinctSfmAggregation + extends AbstractTestNoisySfmAggregation +{ + protected String getFunctionName() + { + return "noisy_approx_distinct_sfm"; + } + + protected long getCardinalityFromResult(Object result) + { + return new Long(result.toString()); + } + + @Test + public void testNonPrivateIntegerCount() + { + Block valuesBlock = createLongSequenceBlock(1, 100_000); + // These estimates are deterministic (no privacy). + assertCardinality(valuesBlock, BIGINT, SfmSketch.NON_PRIVATE_EPSILON, 99_466, 0); + assertCardinality(valuesBlock, BIGINT, SfmSketch.NON_PRIVATE_EPSILON, 8192, 100_219, 0); + assertCardinality(valuesBlock, BIGINT, SfmSketch.NON_PRIVATE_EPSILON, 2048, 32, 100_102, 0); + } + + @Test + public void testPrivateIntegerCount() + { + Block valuesBlock = createLongSequenceBlock(1, 100_000); + // These estimates are random, but not too noisy. + assertCardinality(valuesBlock, BIGINT, 8, 100_000, 25_000); + assertCardinality(valuesBlock, BIGINT, 8, 8192, 100_000, 25_000); + assertCardinality(valuesBlock, BIGINT, 8, 2048, 32, 100_000, 25_000); + } + + @Test + public void testNonPrivateDoubleCount() + { + Block valuesBlock = createDoubleSequenceBlock(1, 100_000); + // These estimates are deterministic (no privacy). + assertCardinality(valuesBlock, DOUBLE, SfmSketch.NON_PRIVATE_EPSILON, 99_670, 0); + assertCardinality(valuesBlock, DOUBLE, SfmSketch.NON_PRIVATE_EPSILON, 8192, 100_078, 0); + assertCardinality(valuesBlock, DOUBLE, SfmSketch.NON_PRIVATE_EPSILON, 2048, 32, 98_350, 0); + } + + @Test + public void testPrivateDoubleCount() + { + Block valuesBlock = createDoubleSequenceBlock(1, 100_000); + // These estimates are random, but not too noisy. + assertCardinality(valuesBlock, DOUBLE, 8, 100_000, 25_000); + assertCardinality(valuesBlock, DOUBLE, 8, 8192, 100_000, 25_000); + assertCardinality(valuesBlock, DOUBLE, 8, 2048, 32, 100_000, 25_000); + } + + @Test + public void testNonPrivateStringCount() + { + Block valuesBlock = createStringSequenceBlock(1, 100_000); + // These estimates are deterministic. + assertCardinality(valuesBlock, VARCHAR, SfmSketch.NON_PRIVATE_EPSILON, 100_190, 0); + assertCardinality(valuesBlock, VARCHAR, SfmSketch.NON_PRIVATE_EPSILON, 8192, 99_982, 0); + assertCardinality(valuesBlock, VARCHAR, SfmSketch.NON_PRIVATE_EPSILON, 2048, 32, 100_773, 0); + } + + @Test + public void testPrivateStringCount() + { + Block valuesBlock = createStringSequenceBlock(1, 100_000); + // These estimates are random, but not too noisy. + assertCardinality(valuesBlock, VARCHAR, 8, 100_000, 25_000); + assertCardinality(valuesBlock, VARCHAR, 8, 8192, 100_000, 25_000); + assertCardinality(valuesBlock, VARCHAR, 8, 2048, 32, 100_000, 25_000); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisyApproximateSetSfmAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisyApproximateSetSfmAggregation.java new file mode 100644 index 0000000000000..42d6ee2cf2ead --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisyApproximateSetSfmAggregation.java @@ -0,0 +1,176 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.SqlVarbinary; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.operator.aggregation.noisyaggregation.sketch.SfmSketch; +import org.testng.annotations.Test; + +import static com.facebook.presto.block.BlockAssertions.createDoubleSequenceBlock; +import static com.facebook.presto.block.BlockAssertions.createLongSequenceBlock; +import static com.facebook.presto.block.BlockAssertions.createStringSequenceBlock; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; + +/** + * Tests for the noisy_approx_set_sfm function. + * Overall, these are similar to the tests of noisy_approx_distinct_sfm, but with an extra check + * to ensure that the size of the returned sketch matches the parameters specified (or defaulted). + */ +public class TestNoisyApproximateSetSfmAggregation + extends AbstractTestNoisySfmAggregation +{ + protected String getFunctionName() + { + return "noisy_approx_set_sfm"; + } + + protected long getCardinalityFromResult(Object result) + { + return getSketchFromResult(result).cardinality(); + } + + private boolean sketchSizesMatch(Object a, Object b) + { + SfmSketch sketchA = getSketchFromResult(a); + SfmSketch sketchB = getSketchFromResult(b); + return sketchA.getBitmap().length() == sketchB.getBitmap().length(); + } + + private void assertSketchSize(Block valuesBlock, Type valueType, double epsilon, int numberOfBuckets, int precision, SqlVarbinary expected) + { + assertFunction(valuesBlock, valueType, epsilon, numberOfBuckets, precision, this::sketchSizesMatch, expected); + } + + private void assertSketchSize(Block valuesBlock, Type valueType, double epsilon, int numberOfBuckets, SqlVarbinary expected) + { + assertFunction(valuesBlock, valueType, epsilon, numberOfBuckets, this::sketchSizesMatch, expected); + } + + private void assertSketchSize(Block valuesBlock, Type valueType, double epsilon, SqlVarbinary expected) + { + assertFunction(valuesBlock, valueType, epsilon, this::sketchSizesMatch, expected); + } + + @Test + public void testNonPrivateInteger() + { + Block valuesBlock = createLongSequenceBlock(1, 100_000); + + SqlVarbinary refSketch = toSqlVarbinary(createLongSketch(4096, 24, 1, 100_000)); + assertCardinality(valuesBlock, BIGINT, SfmSketch.NON_PRIVATE_EPSILON, refSketch, 0); + assertSketchSize(valuesBlock, BIGINT, SfmSketch.NON_PRIVATE_EPSILON, refSketch); + + refSketch = toSqlVarbinary(createLongSketch(8192, 24, 1, 100_000)); + assertCardinality(valuesBlock, BIGINT, SfmSketch.NON_PRIVATE_EPSILON, 8192, refSketch, 0); + assertSketchSize(valuesBlock, BIGINT, SfmSketch.NON_PRIVATE_EPSILON, 8192, refSketch); + + refSketch = toSqlVarbinary(createLongSketch(2048, 32, 1, 100_000)); + assertCardinality(valuesBlock, BIGINT, SfmSketch.NON_PRIVATE_EPSILON, 2048, 32, refSketch, 0); + assertSketchSize(valuesBlock, BIGINT, SfmSketch.NON_PRIVATE_EPSILON, 2048, 32, refSketch); + } + + @Test + public void testPrivateInteger() + { + Block valuesBlock = createLongSequenceBlock(1, 100_000); + + SqlVarbinary refSketch = toSqlVarbinary(createLongSketch(4096, 24, 1, 100_000)); + assertCardinality(valuesBlock, BIGINT, 8, refSketch, 50_000); + assertSketchSize(valuesBlock, BIGINT, 8, refSketch); + + refSketch = toSqlVarbinary(createLongSketch(8192, 24, 1, 100_000)); + assertCardinality(valuesBlock, BIGINT, 8, 8192, refSketch, 50_000); + assertSketchSize(valuesBlock, BIGINT, 8, 8192, refSketch); + + refSketch = toSqlVarbinary(createLongSketch(2048, 32, 1, 100_000)); + assertCardinality(valuesBlock, BIGINT, 8, 2048, 32, refSketch, 50_000); + assertSketchSize(valuesBlock, BIGINT, 8, 2048, 32, refSketch); + } + + @Test + public void testNonPrivateDouble() + { + Block valuesBlock = createDoubleSequenceBlock(1, 100_000); + + SqlVarbinary refSketch = toSqlVarbinary(createDoubleSketch(4096, 24, 1, 100_000)); + assertCardinality(valuesBlock, DOUBLE, SfmSketch.NON_PRIVATE_EPSILON, refSketch, 0); + assertSketchSize(valuesBlock, DOUBLE, SfmSketch.NON_PRIVATE_EPSILON, refSketch); + + refSketch = toSqlVarbinary(createDoubleSketch(8192, 24, 1, 100_000)); + assertCardinality(valuesBlock, DOUBLE, SfmSketch.NON_PRIVATE_EPSILON, 8192, refSketch, 0); + assertSketchSize(valuesBlock, DOUBLE, SfmSketch.NON_PRIVATE_EPSILON, 8192, refSketch); + + refSketch = toSqlVarbinary(createDoubleSketch(2048, 32, 1, 100_000)); + assertCardinality(valuesBlock, DOUBLE, SfmSketch.NON_PRIVATE_EPSILON, 2048, 32, refSketch, 0); + assertSketchSize(valuesBlock, DOUBLE, SfmSketch.NON_PRIVATE_EPSILON, 2048, 32, refSketch); + } + + @Test + public void testPrivateDouble() + { + Block valuesBlock = createDoubleSequenceBlock(1, 100_000); + + SqlVarbinary refSketch = toSqlVarbinary(createDoubleSketch(4096, 24, 1, 100_000)); + assertCardinality(valuesBlock, DOUBLE, 8, refSketch, 50_000); + assertSketchSize(valuesBlock, DOUBLE, 8, refSketch); + + refSketch = toSqlVarbinary(createDoubleSketch(8192, 24, 1, 100_000)); + assertCardinality(valuesBlock, DOUBLE, 8, 8192, refSketch, 50_000); + assertSketchSize(valuesBlock, DOUBLE, 8, 8192, refSketch); + + refSketch = toSqlVarbinary(createDoubleSketch(2048, 32, 1, 100_000)); + assertCardinality(valuesBlock, DOUBLE, 8, 2048, 32, refSketch, 50_000); + assertSketchSize(valuesBlock, DOUBLE, 8, 2048, 32, refSketch); + } + + @Test + public void testNonPrivateString() + { + Block valuesBlock = createStringSequenceBlock(1, 100_000); + + SqlVarbinary refSketch = toSqlVarbinary(createStringSketch(4096, 24, 1, 100_000)); + assertCardinality(valuesBlock, VARCHAR, SfmSketch.NON_PRIVATE_EPSILON, refSketch, 0); + assertSketchSize(valuesBlock, VARCHAR, SfmSketch.NON_PRIVATE_EPSILON, refSketch); + + refSketch = toSqlVarbinary(createStringSketch(8192, 24, 1, 100_000)); + assertCardinality(valuesBlock, VARCHAR, SfmSketch.NON_PRIVATE_EPSILON, 8192, refSketch, 0); + assertSketchSize(valuesBlock, VARCHAR, SfmSketch.NON_PRIVATE_EPSILON, 8192, refSketch); + + refSketch = toSqlVarbinary(createStringSketch(2048, 32, 1, 100_000)); + assertCardinality(valuesBlock, VARCHAR, SfmSketch.NON_PRIVATE_EPSILON, 2048, 32, refSketch, 0); + assertSketchSize(valuesBlock, VARCHAR, SfmSketch.NON_PRIVATE_EPSILON, 2048, 32, refSketch); + } + + @Test + public void testPrivateString() + { + Block valuesBlock = createStringSequenceBlock(1, 100_000); + + SqlVarbinary refSketch = toSqlVarbinary(createStringSketch(4096, 24, 1, 100_000)); + assertCardinality(valuesBlock, VARCHAR, 8, refSketch, 50_000); + assertSketchSize(valuesBlock, VARCHAR, 8, refSketch); + + refSketch = toSqlVarbinary(createStringSketch(8192, 24, 1, 100_000)); + assertCardinality(valuesBlock, VARCHAR, 8, 8192, refSketch, 50_000); + assertSketchSize(valuesBlock, VARCHAR, 8, 8192, refSketch); + + refSketch = toSqlVarbinary(createStringSketch(2048, 32, 1, 100_000)); + assertCardinality(valuesBlock, VARCHAR, 8, 2048, 32, refSketch, 50_000); + assertSketchSize(valuesBlock, VARCHAR, 8, 2048, 32, refSketch); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestSfmSketchStateFactory.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestSfmSketchStateFactory.java new file mode 100644 index 0000000000000..e471189397682 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestSfmSketchStateFactory.java @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +import com.facebook.presto.common.array.ObjectBigArray; +import com.facebook.presto.operator.aggregation.noisyaggregation.sketch.SfmSketch; +import org.openjdk.jol.info.ClassLayout; +import org.testng.annotations.Test; + +import static java.lang.String.format; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; + +public class TestSfmSketchStateFactory +{ + private final SfmSketchStateFactory factory = new SfmSketchStateFactory(); + + @Test + public void testCreateSingleStateEmpty() + { + SfmSketchState state = factory.createSingleState(); + assertNull(state.getSketch()); + assertEquals(state.getEstimatedSize(), ClassLayout.parseClass(SfmSketchStateFactory.SingleSfmSketchState.class).instanceSize()); + } + + @Test + public void testCreateSingleStatePresent() + { + SfmSketchState state = factory.createSingleState(); + SfmSketch sketch = SfmSketch.create(16, 16); + state.setSketch(sketch); + assertEquals(sketch, state.getSketch()); + } + + @Test + public void testCreateGroupedStateEmpty() + { + SfmSketchState state = factory.createGroupedState(); + assertNull(state.getSketch()); + int instanceSize = 40; + long bigObjectSize = (new ObjectBigArray<>()).sizeOf(); + assertTrue(state.getEstimatedSize() >= instanceSize + bigObjectSize, format("Estimated memory size was %d", state.getEstimatedSize())); + } + + @Test + public void testCreateGroupedStatePresent() + { + SfmSketchState state = factory.createGroupedState(); + assertTrue(state instanceof SfmSketchStateFactory.GroupedSfmSketchState); + SfmSketchStateFactory.GroupedSfmSketchState groupedState = (SfmSketchStateFactory.GroupedSfmSketchState) state; + + groupedState.setGroupId(1); + assertNull(state.getSketch()); + SfmSketch sketch1 = SfmSketch.create(16, 16); + groupedState.setSketch(sketch1); + assertEquals(state.getSketch(), sketch1); + + groupedState.setGroupId(2); + assertNull(state.getSketch()); + SfmSketch sketch2 = SfmSketch.create(32, 32); + groupedState.setSketch(sketch2); + assertEquals(state.getSketch(), sketch2); + + groupedState.setGroupId(1); + assertNotNull(state.getSketch()); + } + + @Test + public void testMemoryAccounting() + { + SfmSketchState state = factory.createGroupedState(); + long oldSize = state.getEstimatedSize(); + SfmSketch sketch1 = SfmSketch.create(16, 16); + + state.setSketch(sketch1); + assertEquals( + state.getEstimatedSize(), + oldSize + sketch1.getRetainedSizeInBytes(), + format( + "Expected old size %s plus new sketch sketch size to be equal than new estimate %s", + oldSize, + state.getEstimatedSize())); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestSfmSketchStateSerializer.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestSfmSketchStateSerializer.java new file mode 100644 index 0000000000000..baf89586b0561 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestSfmSketchStateSerializer.java @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.operator.aggregation.noisyaggregation.sketch.SfmSketch; +import com.facebook.presto.operator.aggregation.state.StateCompiler; +import com.facebook.presto.spi.function.AccumulatorStateFactory; +import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.type.SfmSketchType; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; + +public class TestSfmSketchStateSerializer +{ + @Test + public void testSerializeDeserialize() + { + AccumulatorStateFactory factory = StateCompiler.generateStateFactory(SfmSketchState.class); + AccumulatorStateSerializer serializer = StateCompiler.generateStateSerializer(SfmSketchState.class); + SfmSketchState state = factory.createSingleState(); + SfmSketch sketch = SfmSketch.create(16, 16); + state.setSketch(sketch); + state.setEpsilon(0.1); + + BlockBuilder builder = SfmSketchType.SFM_SKETCH.createBlockBuilder(null, 1); + serializer.serialize(state, builder); + Block block = builder.build(); + + state.setSketch(null); + serializer.deserialize(block, 0, state); + + assertNotNull(state.getSketch()); + assertEquals(state.getEpsilon(), 0.1); + } + + @Test + public void testSerializeDeserializeGrouped() + { + AccumulatorStateFactory factory = StateCompiler.generateStateFactory(SfmSketchState.class); + AccumulatorStateSerializer serializer = StateCompiler.generateStateSerializer(SfmSketchState.class); + SfmSketchStateFactory.GroupedSfmSketchState state = (SfmSketchStateFactory.GroupedSfmSketchState) factory.createGroupedState(); + double epsilon1 = 0.1; + double epsilon2 = 0.2; + SfmSketch sketch1 = SfmSketch.create(16, 16); + SfmSketch sketch2 = SfmSketch.create(32, 16); + // Add state to group 1 + state.setGroupId(1); + state.setSketch(sketch1); + state.setEpsilon(epsilon1); + // Add another state to group 2, to show that this doesn't affect the group under test (group 1) + state.setGroupId(2); + state.setSketch(sketch2); + state.setEpsilon(epsilon2); + // Return to group 1 + state.setGroupId(1); + + BlockBuilder builder = SfmSketchType.SFM_SKETCH.createBlockBuilder(null, 1); + serializer.serialize(state, builder); + Block block = builder.build(); + + // Assert the state of group 1 + state.setEpsilon(0.99); + serializer.deserialize(block, 0, state); + state.getSketch().cardinality(); + assertNotNull(state.getSketch()); + assertEquals(state.getEpsilon(), epsilon1); + // Verify nothing changed in group 2 + state.setGroupId(2); + assertNotNull(state.getSketch()); + assertEquals(state.getEpsilon(), epsilon2); + // Groups we did not touch are null + state.setGroupId(3); + assertNull(state.getSketch()); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/TestBitmap.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/TestBitmap.java new file mode 100644 index 0000000000000..83b02eed0eff4 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/TestBitmap.java @@ -0,0 +1,219 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation.sketch; + +import io.airlift.slice.SizeOf; +import org.openjdk.jol.info.ClassLayout; +import org.testng.annotations.Test; + +import java.util.BitSet; +import java.util.Random; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestBitmap +{ + private TestBitmap() {} + + @Test + public static void testRoundTrip() + { + byte[] bytes = randomBytes(100); + assertEquals(Bitmap.fromBytes(100 * 8, bytes).toBytes(), bytes); + } + + @Test + public static void testSetBit() + { + Bitmap bitmap = new Bitmap(24); + + // This should create the following bitmap: + // 00000011_00000101_01010101 + bitmap.setBit(0, true); + bitmap.setBit(1, true); + bitmap.setBit(8, true); + bitmap.setBit(10, true); + bitmap.setBit(16, true); + bitmap.setBit(18, true); + bitmap.setBit(20, true); + bitmap.setBit(22, true); + + byte[] bytes = bitmap.toBytes(); + assertEquals(bytes[0], 0b00000011); + assertEquals(bytes[1], 0b00000101); + assertEquals(bytes[2], 0b01010101); + + // Now clear bits in positions 10+ + // This bitmap should now be: + // 00000011_00000001 [_00000000] (the last byte will be truncated in toBytes()) + for (int i = 10; i < 24; i++) { + bitmap.setBit(i, false); + } + + bytes = bitmap.toBytes(); + assertEquals(bytes.length, 2); + assertEquals(bytes[0], 0b00000011); + assertEquals(bytes[1], 0b00000001); + } + + @Test + public static void testGetBit() + { + Bitmap bitmap = new Bitmap(4096); + + for (int i = 0; i < 4096; i++) { + bitmap.setBit(i, true); + assertTrue(bitmap.getBit(i)); + bitmap.setBit(i, false); + assertFalse(bitmap.getBit(i)); + } + } + + @Test + public void testGetBitCount() + { + int length = 1024; + Bitmap bitmap = new Bitmap(length); + assertEquals(bitmap.getBitCount(), 0); // all zeros at initialization + for (int i = 0; i < length; i++) { + bitmap.setBit(i, true); + assertEquals(bitmap.getBitCount(), i + 1); // i + 1 "true" bits + } + } + + @Test + public static void testFlipBit() + { + Bitmap bitmap = new Bitmap(4096); + + for (int i = 0; i < 4096; i++) { + bitmap.flipBit(i); + assertTrue(bitmap.getBit(i)); + bitmap.flipBit(i); + assertFalse(bitmap.getBit(i)); + bitmap.flipBit(i); + assertTrue(bitmap.getBit(i)); + } + } + + @Test + public static void testByteLength() + { + for (int length : new int[] {8, 800}) { + Bitmap bitmap = new Bitmap(length); + for (int i = 0; i < length; i++) { + bitmap.setBit(i, true); + assertEquals(bitmap.byteLength(), bitmap.toBytes().length); + } + } + } + + @Test + public static void testLength() + { + for (int i = 1; i <= 10; i++) { + Bitmap bitmap = new Bitmap(i * 8); + assertEquals(bitmap.length(), i * 8); + } + } + + @Test + public static void testRandomFlips() + { + Bitmap bitmap = new Bitmap(16); + + // Note: TestingDeterministicRandomizationStrategy flips deterministically if and only if probability >= 0.5. + + TestingDeterministicRandomizationStrategy randomizationStrategy = new TestingDeterministicRandomizationStrategy(); + bitmap.flipBit(0, 0.75, randomizationStrategy); + assertTrue(bitmap.getBit(0)); + bitmap.flipBit(0, 0.75, randomizationStrategy); + assertFalse(bitmap.getBit(0)); + bitmap.flipBit(0, 0.25, randomizationStrategy); + assertFalse(bitmap.getBit(0)); + + bitmap.flipAll(0.75, randomizationStrategy); + for (int i = 0; i < 16; i++) { + assertTrue(bitmap.getBit(i)); + } + + bitmap.flipAll(0.25, randomizationStrategy); + for (int i = 0; i < 16; i++) { + assertTrue(bitmap.getBit(i)); + } + } + + @Test + public static void testClone() + { + Bitmap bitmapA = Bitmap.fromBytes(100 * 8, randomBytes(100)); + Bitmap bitmapB = bitmapA.clone(); + + // all bits should match + for (int i = 0; i < 100 * 8; i++) { + assertEquals(bitmapA.getBit(i), bitmapB.getBit(i)); + } + + // but the bitmaps should point to different bits + bitmapA.flipBit(0); + assertEquals(bitmapA.getBit(0), !bitmapB.getBit(0)); + } + + @Test + public static void testOr() + { + Bitmap bitmapA = Bitmap.fromBytes(100 * 8, randomBytes(100)); + Bitmap bitmapB = Bitmap.fromBytes(100 * 8, randomBytes(100)); + Bitmap bitmapC = bitmapA.clone(); + bitmapC.or(bitmapB); + + for (int i = 0; i < 100 * 8; i++) { + assertEquals(bitmapC.getBit(i), bitmapA.getBit(i) | bitmapB.getBit(i)); + } + } + + @Test + public static void testXor() + { + Bitmap bitmapA = Bitmap.fromBytes(100 * 8, randomBytes(100)); + Bitmap bitmapB = Bitmap.fromBytes(100 * 8, randomBytes(100)); + Bitmap bitmapC = bitmapA.clone(); + bitmapC.xor(bitmapB); + + for (int i = 0; i < 100 * 8; i++) { + assertEquals(bitmapC.getBit(i), bitmapA.getBit(i) ^ bitmapB.getBit(i)); + } + } + + @Test + public static void testRetainedSize() + { + int instanceSizes = ClassLayout.parseClass(Bitmap.class).instanceSize() + ClassLayout.parseClass(BitSet.class).instanceSize(); + + // The underlying BitSet stores a long[] array of size length / 64, + // even though toBytes() returns a truncated array of bytes. + Bitmap bitmap = new Bitmap(1024); + assertEquals(bitmap.getRetainedSizeInBytes(), instanceSizes + SizeOf.sizeOfLongArray(1024 / 64)); + } + + private static byte[] randomBytes(int length) + { + byte[] bytes = new byte[length]; + Random random = new Random(); + random.nextBytes(bytes); + return bytes; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/TestSfmSketch.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/TestSfmSketch.java new file mode 100644 index 0000000000000..a4add1e1e7552 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/TestSfmSketch.java @@ -0,0 +1,389 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation.sketch; + +import io.airlift.slice.Slice; +import org.openjdk.jol.info.ClassLayout; +import org.testng.annotations.Test; + +import static io.airlift.slice.testing.SliceAssertions.assertSlicesEqual; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertTrue; + +public class TestSfmSketch +{ + @Test + public void testComputeIndex() + { + for (int indexBitLength : new int[]{6, 12, 18}) { + long index = 5L; + long hash = index << (Long.SIZE - indexBitLength); + assertEquals(SfmSketch.computeIndex(hash, indexBitLength), index); + } + } + + @Test + public void testIndexBitLength() + { + for (int i = 1; i < 20; i++) { + assertEquals(SfmSketch.indexBitLength((int) Math.pow(2, i)), i); + } + } + + @Test + public void testNumberOfTrailingZeros() + { + for (int indexBitLength : new int[]{6, 12, 18}) { + for (int i = 0; i < Long.SIZE - 1; i++) { + long hash = 1L << i; + assertEquals(SfmSketch.numberOfTrailingZeros(hash, indexBitLength), Math.min(i, Long.SIZE - indexBitLength)); + } + } + } + + @Test + public void testNumberOfBuckets() + { + for (int i = 1; i < 20; i++) { + assertEquals(SfmSketch.numberOfBuckets(i), Math.round(Math.pow(2, i))); + } + } + + @Test + public void testPowerOf2() + { + for (int i = 1; i < 20; i++) { + assertTrue(SfmSketch.isPowerOf2(Math.round(Math.pow(2, i)))); + assertFalse(SfmSketch.isPowerOf2(Math.round(Math.pow(2, i)) + 1)); + } + } + + @Test + public void testRoundTrip() + { + SfmSketch sketch = SfmSketch.create(4096, 24); + for (int i = 0; i < 100_000; i++) { + sketch.add(i); + } + sketch.enablePrivacy(2, new TestingSeededRandomizationStrategy(1)); + Slice serialized = sketch.serialize(); + SfmSketch unserialized = SfmSketch.deserialize(serialized); + assertSlicesEqual(serialized, unserialized.serialize()); + } + + @Test + public void testPrivacyEnabled() + { + SfmSketch sketch = SfmSketch.create(32, 24); + assertFalse(sketch.isPrivacyEnabled()); + sketch.enablePrivacy(SfmSketch.NON_PRIVATE_EPSILON); + assertFalse(sketch.isPrivacyEnabled()); + sketch.enablePrivacy(1.23); + assertTrue(sketch.isPrivacyEnabled()); + } + + @Test + public void testSerializedSize() + { + SfmSketch sketch = SfmSketch.create(4096, 24); + sketch.enablePrivacy(1.23); + assertEquals(sketch.estimatedSerializedSize(), sketch.serialize().length()); + } + + @Test + public void testRetainedSize() + { + SfmSketch sketch = SfmSketch.create(4096, 24); + sketch.enablePrivacy(4); + assertEquals(sketch.getRetainedSizeInBytes(), + ClassLayout.parseClass(SfmSketch.class).instanceSize() + + sketch.getBitmap().getRetainedSizeInBytes()); + } + + @Test + public void testBitmapSize() + { + int[] buckets = {32, 64, 512, 1024, 4096, 32768}; + int[] precisions = {1, 2, 3, 8, 24, 32}; + + for (int numberOfBuckets : buckets) { + for (int precision : precisions) { + SfmSketch sketch = SfmSketch.create(numberOfBuckets, precision); + assertEquals(sketch.getBitmap().length(), numberOfBuckets * precision); + } + } + } + + @Test + public void testMergeNonPrivate() + { + SfmSketch sketch = SfmSketch.create(4096, 24); + SfmSketch sketch2 = SfmSketch.create(4096, 24); + + // insert 100,000 non-negative integers + 100,000 negative integers + for (int i = 0; i < 100_000; i++) { + sketch.add(i); + sketch2.add(-i - 1); + } + Bitmap refBitmap = sketch.getBitmap().clone(); // clone old bitmap + sketch.mergeWith(sketch2); + + // The two bitmaps should be merged with OR, + // and the resulting bitmap is not private. + refBitmap.or(sketch2.getBitmap()); + assertEquals(sketch.getBitmap().toBytes(), refBitmap.toBytes()); + assertFalse(sketch.isPrivacyEnabled()); + } + + @Test + public void testMergePrivate() + { + SfmSketch sketch = SfmSketch.create(4096, 24); + SfmSketch sketch2 = SfmSketch.create(4096, 24); + + // insert 100,000 non-negative integers + 100,000 negative integers + for (int i = 0; i < 100_000; i++) { + sketch.add(i); + sketch2.add(-i - 1); + } + + Bitmap nonPrivateBitmap = sketch.getBitmap().clone(); + Bitmap nonPrivateBitmap2 = sketch2.getBitmap().clone(); + + sketch.enablePrivacy(3, new TestingSeededRandomizationStrategy(1)); + sketch2.enablePrivacy(4, new TestingSeededRandomizationStrategy(2)); + double p1 = sketch.getRandomizedResponseProbability(); + double p2 = sketch2.getRandomizedResponseProbability(); + + Bitmap refBitmap = sketch.getBitmap().clone(); // clone existing bitmap + refBitmap.or(sketch2.getBitmap()); // take an OR with sketch2, for later comparison + sketch.mergeWith(sketch2, new TestingSeededRandomizationStrategy(3)); + + // The resulting bitmap is a randomized merge equivalent to a noisy (not deterministic) OR. + // As a result, the bitmap should not equal an OR, but it should have roughly the same number + // of 1-bits as an OR that is flipped with the merged randomizedResponseProbability. + // The resulting merged sketch is private. + assertTrue(sketch.isPrivacyEnabled()); + assertEquals(sketch.getRandomizedResponseProbability(), SfmSketch.mergeRandomizedResponseProbabilities(p1, p2)); + assertNotEquals(sketch.getBitmap().toBytes(), refBitmap.toBytes()); + + int actualBitCount = sketch.getBitmap().getBitCount(); + Bitmap hypotheticalBitmap = nonPrivateBitmap.clone(); + hypotheticalBitmap.or(nonPrivateBitmap2); + hypotheticalBitmap.flipAll(sketch.getRandomizedResponseProbability(), new TestingSeededRandomizationStrategy(1)); + // The number of 1-bits in the merged sketch should approximately equal the number of 1-bits in our hypothetical bitmap. + assertEquals(hypotheticalBitmap.getBitCount(), actualBitCount, 100); + } + + @Test + public void testMergeMixed() + { + SfmSketch sketch = SfmSketch.create(4096, 24); + SfmSketch sketch2 = SfmSketch.create(4096, 24); + for (int i = 0; i < 100_000; i++) { + sketch.add(i); + sketch2.add(-i - 1); + } + sketch2.enablePrivacy(3, new TestingSeededRandomizationStrategy(1)); + Bitmap before = sketch.getBitmap().clone(); + sketch.mergeWith(sketch2, new TestingSeededRandomizationStrategy(2)); + + // The resulting sketch is private. + assertTrue(sketch.isPrivacyEnabled()); + + // A mixed-privacy merge is mathematically similar to a normal private merge, but + // it turns out that some bits are deterministic. In particular, the bits of the + // merged sketch corresponding to 0s in the non-private sketch should exactly match + // the private sketch. + for (int i = 0; i < before.length(); i++) { + if (!before.getBit(i)) { + assertEquals(sketch.getBitmap().getBit(i), sketch2.getBitmap().getBit(i)); + } + } + } + + @Test + public void testMergedProbabilities() + { + // should be symmetric + assertEquals(SfmSketch.mergeRandomizedResponseProbabilities(0.1, 0.2), SfmSketch.mergeRandomizedResponseProbabilities(0.2, 0.1)); + + // private + nonprivate = private + assertEquals(SfmSketch.mergeRandomizedResponseProbabilities(0, 0.1), 0.1); + assertEquals(SfmSketch.mergeRandomizedResponseProbabilities(0.15, 0), 0.15); + + // nonprivate + nonprivate = nonprivate + assertEquals(SfmSketch.mergeRandomizedResponseProbabilities(0.0, 0.0), 0.0); + + // private + private = private (noisier) + // In particular, according to https://arxiv.org/pdf/2302.02056.pdf, Theorem 4.8, two sketches + // with epsilon1 and epsilon2 should have a merged epsilonStar of: + // -log(e^-epsilon1 + e^-epsilon2 - e^-(epsilon1 + epsilon2)) + double epsilon1 = 1.2; + double epsilon2 = 3.4; + double p1 = SfmSketch.getRandomizedResponseProbability(epsilon1); + double p2 = SfmSketch.getRandomizedResponseProbability(epsilon2); + double epsilonStar = -Math.log(Math.exp(-epsilon1) + Math.exp(-epsilon2) - Math.exp(-(epsilon1 + epsilon2))); + double pStar = SfmSketch.getRandomizedResponseProbability(epsilonStar); + assertEquals(SfmSketch.mergeRandomizedResponseProbabilities(p1, p2), pStar, 1E-6); + // note: the merged sketch is noisier (higher probability of flipped bits) + assertTrue(pStar > Math.max(p1, p2)); + } + + @Test + public void testEmptySketchCardinality() + { + SfmSketch nonPrivateSketch = SfmSketch.create(4096, 24); + SfmSketch privateSketch = SfmSketch.create(4096, 24); + privateSketch.enablePrivacy(3, new TestingSeededRandomizationStrategy(1)); + + // Non-private should return exactly 0 + assertEquals(nonPrivateSketch.cardinality(), 0); + + // Private will be a noisy sketch, so it should return approximately zero, but will be rather noisy. + assertEquals(privateSketch.cardinality(), 0, 200); + } + + @Test + public void testSmallCardinality() + { + int[] ns = {1, 5, 10, 50, 100, 200, 500, 1000}; + + for (int n : ns) { + SfmSketch nonPrivateSketch = SfmSketch.create(4096, 24); + SfmSketch privateSketch = SfmSketch.create(4096, 24); + + for (int i = 0; i < n; i++) { + nonPrivateSketch.add(i); + privateSketch.add(i); + } + + privateSketch.enablePrivacy(3, new TestingSeededRandomizationStrategy(1)); + + // Non-private should actually be quite good for small numbers + assertEquals(nonPrivateSketch.cardinality(), n, Math.max(10, 0.1 * n)); + + // Private isn't quite as good... + assertEquals(privateSketch.cardinality(), n, 200); + } + } + + @Test + public void testActualCardinalityEstimates() + { + // Note: this is slow for cardinalities beyond, say, 1 million. See `testSimulatedCardinalityEstimates` below. + int[] magnitudes = {4, 5, 6}; + double[] epsilons = {2, 4, SfmSketch.NON_PRIVATE_EPSILON}; + for (int mag : magnitudes) { + int n = (int) Math.pow(10, mag); + for (double eps : epsilons) { + SfmSketch sketch = SfmSketch.create(4096, 24); + for (int i = 0; i < n; i++) { + sketch.add(i); + } + sketch.enablePrivacy(eps, new TestingSeededRandomizationStrategy(1)); + assertEquals(sketch.cardinality(), n, n * 0.05); // answers should be accurate to within 5% (arbitrary) + } + } + } + + @Test + public void testSimulatedCardinalityEstimates() + { + // Instead of creating sketches by adding items, we simulate them for fast testing of huge cardinalities. + // For reference, 10^33 is one decillion. + // The goal here is to test general functionality and numerical stability. + int[] magnitudes = {6, 9, 12, 15, 18, 21, 24, 27, 30, 33}; + double[] epsilons = {4, SfmSketch.NON_PRIVATE_EPSILON}; + for (int mag : magnitudes) { + int n = (int) Math.pow(10, mag); + for (double eps : epsilons) { + SfmSketch sketch = createSketchWithTargetCardinality(4096, 24, eps, n); + assertEquals(sketch.cardinality(), n, n * 0.1); + } + } + } + + @Test + public void testMergedCardinalities() + { + double[] epsilons = {3, 4, SfmSketch.NON_PRIVATE_EPSILON}; + + // Test each pair of epsilons + // This gives us equal private epsilons, unequal private epsilons, mixed private and nonprivate, and totally nonprivate + for (double eps1 : epsilons) { + for (double eps2 : epsilons) { + SfmSketch sketch = SfmSketch.create(4096, 24); + SfmSketch sketch2 = SfmSketch.create(4096, 24); + // insert 300,000 positive integers and 200,000 negative integers + for (int i = 0; i < 300_000; i++) { + sketch.add(i + 1); + if (i < 200_000) { + sketch2.add(-i); + } + } + + sketch.enablePrivacy(eps1, new TestingSeededRandomizationStrategy(1)); + sketch2.enablePrivacy(eps2, new TestingSeededRandomizationStrategy(2)); + sketch.mergeWith(sketch2); + assertEquals(sketch.cardinality(), 500_000, 500_000 * 0.1); + } + } + } + + @Test + public void testEnablePrivacy() + { + SfmSketch sketch = SfmSketch.create(4096, 24); + double epsilon = 4; + + for (int i = 0; i < 100_000; i++) { + sketch.add(i); + } + + long cardinalityBefore = sketch.cardinality(); + sketch.enablePrivacy(epsilon, new TestingSeededRandomizationStrategy(1)); + long cardinalityAfter = sketch.cardinality(); + + // Randomized response probability should reflect the new (private) epsilon + assertEquals(sketch.getRandomizedResponseProbability(), SfmSketch.getRandomizedResponseProbability(epsilon)); + assertTrue(sketch.isPrivacyEnabled()); + + // Cardinality should remain approximately the same + assertEquals(cardinalityAfter, cardinalityBefore, cardinalityBefore * 0.1); + } + + private static SfmSketch createSketchWithTargetCardinality(int numberOfBuckets, int precision, double epsilon, int cardinality) + { + // Building a sketch by adding items is really slow (O(n)) if you want to test billions/trillions/quadrillions/etc. + // Simulating the sketch is much faster (O(buckets * precision)). + RandomizationStrategy randomizationStrategy = new TestingSeededRandomizationStrategy(1); + SfmSketch sketch = SfmSketch.create(numberOfBuckets, precision); + Bitmap bitmap = sketch.getBitmap(); + double c1 = sketch.getOnProbability(); + double c2 = sketch.getOnProbability() - sketch.getRandomizedResponseProbability(); + + for (int l = 0; l < precision; l++) { + double p = c1 - c2 * Math.pow(1 - Math.pow(2, -(l + 1)) / numberOfBuckets, cardinality); + for (int b = 0; b < numberOfBuckets; b++) { + bitmap.setBit(sketch.getBitLocation(b, l), randomizationStrategy.nextBoolean(p)); + } + } + + sketch.enablePrivacy(epsilon, randomizationStrategy); + return sketch; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/TestingDeterministicRandomizationStrategy.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/TestingDeterministicRandomizationStrategy.java new file mode 100644 index 0000000000000..16e2bee23ff41 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/TestingDeterministicRandomizationStrategy.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation.sketch; + +/** + * Non-random numbers for testing + */ +public class TestingDeterministicRandomizationStrategy + extends RandomizationStrategy +{ + public TestingDeterministicRandomizationStrategy() {} + + public double nextDouble() + { + return 0.5; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/TestingSeededRandomizationStrategy.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/TestingSeededRandomizationStrategy.java new file mode 100644 index 0000000000000..aa43c76060640 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/TestingSeededRandomizationStrategy.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.noisyaggregation.sketch; + +import java.util.Random; + +/** + * Seeded random numbers for testing + */ +public class TestingSeededRandomizationStrategy + extends RandomizationStrategy +{ + private final Random random; + + public TestingSeededRandomizationStrategy(long seed) + { + this.random = new Random(seed); + } + + public double nextDouble() + { + return random.nextDouble(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestSfmSketchFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestSfmSketchFunctions.java new file mode 100644 index 0000000000000..18710f30d6a34 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestSfmSketchFunctions.java @@ -0,0 +1,103 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.operator.aggregation.noisyaggregation.sketch.SfmSketch; +import com.google.common.io.BaseEncoding; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.testing.assertions.Assert.assertEquals; + +public class TestSfmSketchFunctions + extends AbstractTestFunctions +{ + @Test + public void testCardinality() + { + SfmSketch sketch = createSketch(1, 10_000, 4); + assertEquals(SfmSketchFunctions.cardinality(sketch.serialize()), sketch.cardinality()); + } + + @Test + public void testEmptyApproxSet() + { + // with no privacy (epsilon = infinity), an empty approx set should return 0 cardinality + assertFunction("cardinality(noisy_empty_approx_set_sfm(infinity()))", BIGINT, 0L); + assertFunction("cardinality(noisy_empty_approx_set_sfm(infinity(), 4096))", BIGINT, 0L); + assertFunction("cardinality(noisy_empty_approx_set_sfm(infinity(), 4096, 24))", BIGINT, 0L); + } + + @Test + public void testCastRoundTrip() + { + assertFunction("cardinality(CAST(CAST(noisy_empty_approx_set_sfm(infinity()) AS VARBINARY) AS SFMSKETCH))", BIGINT, 0L); + } + + @Test + public void testMergeNullArray() + { + assertFunction("merge_sfm(ARRAY[NULL, NULL, NULL]) IS NULL", BOOLEAN, true); + } + + @Test + public void testMergeEmptyArray() + { + // calling with an empty array should return NULL + assertFunction("merge_sfm(ARRAY[]) IS NULL", BOOLEAN, true); + } + + @Test + public void testMergeSingleArray() + { + // merging a single SFM sketch should simply return the sketch + String sketchProjection = getSketchProjection(createSketch(1, 10_000, 3)); + assertFunction("cardinality(merge_sfm(ARRAY[" + sketchProjection + "])) = cardinality(" + sketchProjection + ")", BOOLEAN, true); + } + + @Test + public void testMergeManyArrays() + { + // merging many sketches should return a single merged sketch + // (using non-private sketches here for a deterministic test) + String sketchProjection1 = getSketchProjection(createSketch(1, 50, SfmSketch.NON_PRIVATE_EPSILON)); + String sketchProjection2 = getSketchProjection(createSketch(51, 200, SfmSketch.NON_PRIVATE_EPSILON)); + String sketchProjection3 = getSketchProjection(createSketch(100, 300, SfmSketch.NON_PRIVATE_EPSILON)); + String sketchProjectionMerged = getSketchProjection(createSketch(1, 300, SfmSketch.NON_PRIVATE_EPSILON)); + String arrayProjection = "ARRAY[" + sketchProjection1 + ", " + sketchProjection2 + ", " + sketchProjection3 + "]"; + assertFunction("CAST(merge_sfm(" + arrayProjection + ") AS VARBINARY) = CAST(" + sketchProjectionMerged + " AS VARBINARY)", BOOLEAN, true); + } + + private SfmSketch createSketch(int start, int end, double epsilon) + { + SfmSketch sketch = SfmSketch.create(2048, 16); + for (int i = start; i <= end; i++) { + sketch.add(i); + } + + if (epsilon < SfmSketch.NON_PRIVATE_EPSILON) { + sketch.enablePrivacy(epsilon); + } + + return sketch; + } + + private String getSketchProjection(SfmSketch sketch) + { + byte[] binary = sketch.serialize().getBytes(); + String encoded = BaseEncoding.base16().lowerCase().encode(binary); + return "CAST(X'" + encoded + "' AS SFMSKETCH)"; + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestNoisyAggregations.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestNoisyAggregations.java index 4f4d6ba4056f1..bb1ba3f7905f4 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestNoisyAggregations.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestNoisyAggregations.java @@ -117,6 +117,25 @@ public void testNoisyAvgGaussianZeroNoiseScaleRandomSeedVsNormalCount() assertQueryWithSingleDoubleRow("SELECT noisy_avg_gaussian(nationkey, 0, 10) FROM nation", "SELECT avg(nationkey) FROM nation"); // INTEGER } + @Test + public void testNoisyApproxSetVsApproxDistinct() + { + assertQuery("SELECT noisy_approx_distinct_sfm(linenumber, infinity()) FROM lineitem", + "SELECT cardinality(noisy_approx_set_sfm(linenumber, infinity())) FROM lineitem"); + assertQuery("SELECT noisy_approx_distinct_sfm(linenumber, infinity(), 2048) FROM lineitem", + "SELECT cardinality(noisy_approx_set_sfm(linenumber, infinity(), 2048)) FROM lineitem"); + assertQuery("SELECT noisy_approx_distinct_sfm(linenumber, infinity(), 8192, 32) FROM lineitem", + "SELECT cardinality(noisy_approx_set_sfm(linenumber, infinity(), 8192, 32)) FROM lineitem"); + } + + @Test + public void testNoisyApproxSetMergedVsApproxDistinct() + { + assertQuery("SELECT cardinality(merge(sketch)) FROM " + + "(SELECT noisy_approx_set_sfm(linenumber, infinity()) AS sketch FROM lineitem GROUP BY mod(linenumber, 10))", + "SELECT noisy_approx_distinct_sfm(linenumber, infinity()) FROM lineitem"); + } + private void assertQueryWithSingleDoubleRow(@Language("SQL") String actual, @Language("SQL") String expected) { MaterializedResult actualResult = computeActual(actual);