diff --git a/presto-docs/src/main/sphinx/functions/aggregate.rst b/presto-docs/src/main/sphinx/functions/aggregate.rst index 9517c6088856c..1902e19795b1e 100644 --- a/presto-docs/src/main/sphinx/functions/aggregate.rst +++ b/presto-docs/src/main/sphinx/functions/aggregate.rst @@ -564,6 +564,7 @@ To find the `ROC curve @@ -39,36 +36,9 @@ public Type getSerializedType() public void serialize(DifferentialEntropyState state, BlockBuilder output) { DifferentialEntropyStateStrategy strategy = state.getStrategy(); - - if (strategy == null) { - SliceOutput sliceOut = Slices.allocate(SizeOf.SIZE_OF_INT).getOutput(); - sliceOut.appendInt(0); - VARBINARY.writeSlice(output, sliceOut.getUnderlyingSlice()); - return; - } - - int requiredBytes = SizeOf.SIZE_OF_INT + // Method - strategy.getRequiredBytesForSerialization(); // stateStrategy; - + int requiredBytes = DifferentialEntropyStateStrategy.getRequiredBytesForSerialization(strategy); SliceOutput sliceOut = Slices.allocate(requiredBytes).getOutput(); - - if (strategy instanceof UnweightedReservoirSampleStateStrategy) { - sliceOut.appendInt(1); - } - else if (strategy instanceof WeightedReservoirSampleStateStrategy) { - sliceOut.appendInt(2); - } - else if (strategy instanceof FixedHistogramMleStateStrategy) { - sliceOut.appendInt(3); - } - else if (strategy instanceof FixedHistogramJacknifeStateStrategy) { - sliceOut.appendInt(4); - } - else { - verify(false, format("Strategy cannot be serialized: %s", strategy.getClass().getSimpleName())); - } - - strategy.serialize(sliceOut); + DifferentialEntropyStateStrategy.serialize(strategy, sliceOut); VARBINARY.writeSlice(output, sliceOut.getUnderlyingSlice()); } @@ -79,25 +49,9 @@ public void deserialize( DifferentialEntropyState state) { SliceInput input = VARBINARY.getSlice(block, index).getInput(); - int method = input.readInt(); - switch (method) { - case 0: - verify(state.getStrategy() == null, "strategy is not null for null method"); - return; - case 1: - state.setStrategy(UnweightedReservoirSampleStateStrategy.deserialize(input)); - return; - case 2: - state.setStrategy(WeightedReservoirSampleStateStrategy.deserialize(input)); - return; - case 3: - state.setStrategy(FixedHistogramMleStateStrategy.deserialize(input)); - return; - case 4: - state.setStrategy(FixedHistogramJacknifeStateStrategy.deserialize(input)); - return; - default: - verify(false, format("Unknown method code when deserializing: %s", method)); + DifferentialEntropyStateStrategy strategy = DifferentialEntropyStateStrategy.deserialize(input); + if (strategy != null) { + state.setStrategy(strategy); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/DifferentialEntropyStateStrategy.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/DifferentialEntropyStateStrategy.java index 3288b7a731dc6..957952fd960cb 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/DifferentialEntropyStateStrategy.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/DifferentialEntropyStateStrategy.java @@ -13,8 +13,14 @@ */ package com.facebook.presto.operator.aggregation.differentialentropy; +import com.facebook.presto.spi.PrestoException; +import com.google.common.annotations.VisibleForTesting; +import io.airlift.slice.SizeOf; +import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.google.common.base.Verify.verify; import static java.lang.String.format; /** @@ -25,13 +31,114 @@ public interface DifferentialEntropyStateStrategy extends Cloneable { - void add(double sample, double weight); + @VisibleForTesting + String FIXED_HISTOGRAM_MLE_METHOD_NAME = "fixed_histogram_mle"; + @VisibleForTesting + String FIXED_HISTOGRAM_JACKNIFE_METHOD_NAME = "fixed_histogram_jacknife"; + + static DifferentialEntropyStateStrategy getStrategy( + DifferentialEntropyStateStrategy strategy, + long size, + double sample, + double weight, + String method, + double min, + double max) + { + if (strategy == null) { + switch (method) { + case DifferentialEntropyStateStrategy.FIXED_HISTOGRAM_MLE_METHOD_NAME: + strategy = new FixedHistogramMleStateStrategy(size, min, max); + break; + case DifferentialEntropyStateStrategy.FIXED_HISTOGRAM_JACKNIFE_METHOD_NAME: + strategy = new FixedHistogramJacknifeStateStrategy(size, min, max); + break; + default: + throw new PrestoException( + INVALID_FUNCTION_ARGUMENT, + format("In differential_entropy UDF, invalid method: %s", method)); + } + } + else { + switch (method) { + case DifferentialEntropyStateStrategy.FIXED_HISTOGRAM_MLE_METHOD_NAME: + if (!(strategy instanceof FixedHistogramMleStateStrategy)) { + throw new PrestoException( + INVALID_FUNCTION_ARGUMENT, + format("In differential_entropy, strategy class is not compatible with entropy method: %s %s", strategy.getClass().getSimpleName(), method)); + } + break; + case DifferentialEntropyStateStrategy.FIXED_HISTOGRAM_JACKNIFE_METHOD_NAME: + if (!(strategy instanceof FixedHistogramJacknifeStateStrategy)) { + throw new PrestoException( + INVALID_FUNCTION_ARGUMENT, + format("In differential_entropy, strategy class is not compatible with entropy method: %s %s", strategy.getClass().getSimpleName(), method)); + } + break; + default: + throw new PrestoException( + INVALID_FUNCTION_ARGUMENT, + format("In differential_entropy, unknown entropy method: %s", method)); + } + } + strategy.validateParameters(size, sample, weight, min, max); + return strategy; + } + + static DifferentialEntropyStateStrategy getStrategy( + DifferentialEntropyStateStrategy strategy, + long size, + double sample, + double weight) + { + if (strategy == null) { + strategy = new WeightedReservoirSampleStateStrategy(size); + } + else { + verify(strategy instanceof WeightedReservoirSampleStateStrategy, + format("In differential entropy, expected WeightedReservoirSampleStateStrategy, got: %s", strategy.getClass().getSimpleName())); + } + strategy.validateParameters(size, sample, weight); + return strategy; + } + + static DifferentialEntropyStateStrategy getStrategy( + DifferentialEntropyStateStrategy strategy, + long size, + double sample) + { + if (strategy == null) { + strategy = new UnweightedReservoirSampleStateStrategy(size); + } + else { + verify(strategy instanceof UnweightedReservoirSampleStateStrategy, + format("In differential entropy, expected UnweightedReservoirSampleStateStrategy, got: %s", strategy.getClass().getSimpleName())); + } + return strategy; + } + + default void add(double sample) + { + verify(false, format("Unweighted unsupported for type: %s", getClass().getSimpleName())); + } + + default void add(double sample, double weight) + { + verify(false, format("Weighted unsupported for type: %s", getClass().getSimpleName())); + } double calculateEntropy(); long getEstimatedSize(); - int getRequiredBytesForSerialization(); + static int getRequiredBytesForSerialization(DifferentialEntropyStateStrategy strategy) + { + return SizeOf.SIZE_OF_INT + // magic hash + SizeOf.SIZE_OF_INT + // method + (strategy == null ? 0 : strategy.getRequiredBytesForSpecificSerialization()); + } + + int getRequiredBytesForSpecificSerialization(); void serialize(SliceOutput out); @@ -39,21 +146,85 @@ public interface DifferentialEntropyStateStrategy DifferentialEntropyStateStrategy clone(); - default void validateParameters(long bucketCount, double sample, double weight, double min, double max) + default void validateParameters(long size, double sample, double weight, double min, double max) { throw new UnsupportedOperationException( format("In differential_entropy UDF, unsupported arguments for type: %s", getClass().getSimpleName())); } - default void validateParameters(long bucketCount, double sample, double weight) + default void validateParameters(long size, double sample, double weight) { throw new UnsupportedOperationException( format("In differential_entropy UDF, unsupported arguments for type: %s", getClass().getSimpleName())); } - default void validateParameters(long bucketCount, double sample) + default void validateParameters(long size, double sample) { throw new UnsupportedOperationException( format("In differential_entropy UDF, unsupported arguments for type: %s", getClass().getSimpleName())); } + + static void serialize(DifferentialEntropyStateStrategy strategy, SliceOutput sliceOut) + { + sliceOut.appendInt(DifferentialEntropyStateStrategy.class.getSimpleName().hashCode()); + if (strategy == null) { + sliceOut.appendInt(0); + return; + } + + if (strategy instanceof UnweightedReservoirSampleStateStrategy) { + sliceOut.appendInt(1); + } + else if (strategy instanceof WeightedReservoirSampleStateStrategy) { + sliceOut.appendInt(2); + } + else if (strategy instanceof FixedHistogramMleStateStrategy) { + sliceOut.appendInt(3); + } + else if (strategy instanceof FixedHistogramJacknifeStateStrategy) { + sliceOut.appendInt(4); + } + else { + verify(false, format("Strategy cannot be serialized: %s", strategy.getClass().getSimpleName())); + } + + strategy.serialize(sliceOut); + } + + static DifferentialEntropyStateStrategy deserialize(SliceInput input) + { + verify( + input.readInt() == DifferentialEntropyStateStrategy.class.getSimpleName().hashCode(), + "magic failed"); + int method = input.readInt(); + switch (method) { + case 0: + return null; + case 1: + return UnweightedReservoirSampleStateStrategy.deserialize(input); + case 2: + return WeightedReservoirSampleStateStrategy.deserialize(input); + case 3: + return FixedHistogramMleStateStrategy.deserialize(input); + case 4: + return FixedHistogramJacknifeStateStrategy.deserialize(input); + default: + verify(false, format("In differential_entropy UDF, Unknown method code when deserializing: %s", method)); + return null; + } + } + + static void combine( + DifferentialEntropyStateStrategy strategy, + DifferentialEntropyStateStrategy otherStrategy) + { + verify(strategy.getClass() == otherStrategy.getClass(), + format("In combine, %s != %s", strategy.getClass().getSimpleName(), otherStrategy.getClass().getSimpleName())); + + strategy.mergeWith(otherStrategy); + } + + DifferentialEntropyStateStrategy cloneEmpty(); + + double getTotalPopulationWeight(); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/EntropyCalculations.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/EntropyCalculations.java index bac556c9b0793..479a6b99d7ee6 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/EntropyCalculations.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/EntropyCalculations.java @@ -15,6 +15,7 @@ import java.util.Arrays; +import static com.google.common.base.Verify.verify; import static java.lang.Math.toIntExact; public class EntropyCalculations @@ -24,7 +25,7 @@ private EntropyCalculations() {} /** * @implNote Based on Alizadeh Noughabi, Hadi & Arghami, N. (2010). "A New Estimator of Entropy". */ - public static double calculateFromSamples(double[] samples) + public static double calculateFromSamplesUsingVasicek(double[] samples) { if (samples.length == 0) { return Double.NaN; @@ -42,4 +43,10 @@ public static double calculateFromSamples(double[] samples) } return entropy / n / Math.log(2); } + + static double calculateEntropyFromHistogramAggregates(double width, double sumWeight, double sumWeightLogWeight) + { + verify(sumWeight > 0.0); + return Math.max((Math.log(width * sumWeight) - sumWeightLogWeight / sumWeight) / Math.log(2.0), 0.0); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/FixedHistogramJacknifeStateStrategy.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/FixedHistogramJacknifeStateStrategy.java index c9eb31f0fa3fe..3d9316b60149f 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/FixedHistogramJacknifeStateStrategy.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/FixedHistogramJacknifeStateStrategy.java @@ -19,8 +19,8 @@ import java.util.Map; +import static com.facebook.presto.operator.aggregation.differentialentropy.EntropyCalculations.calculateEntropyFromHistogramAggregates; import static com.facebook.presto.operator.aggregation.differentialentropy.FixedHistogramStateStrategyUtils.getXLogX; -import static com.google.common.base.Verify.verify; import static com.google.common.collect.Streams.stream; import static java.lang.Math.toIntExact; import static java.util.stream.Collectors.groupingBy; @@ -81,6 +81,14 @@ public void add(double value, double weight) histogram.add(value, weight); } + @Override + public double getTotalPopulationWeight() + { + return stream(histogram.iterator()) + .mapToDouble(FixedDoubleBreakdownHistogram.Bucket::getWeight) + .sum(); + } + @Override public double calculateEntropy() { @@ -98,7 +106,7 @@ public double calculateEntropy() double sumWeightLogWeight = bucketWeights.values().stream().mapToDouble(w -> w == 0.0 ? 0.0 : w * Math.log(w)).sum(); - double entropy = n * calculateEntropy(histogram.getWidth(), sumWeight, sumWeightLogWeight); + double entropy = n * calculateEntropyFromHistogramAggregates(histogram.getWidth(), sumWeight, sumWeightLogWeight); for (FixedDoubleBreakdownHistogram.Bucket bucketWeight : histogram) { double weight = bucketWeights.get(bucketWeight.getLeft()); if (weight > 0.0) { @@ -130,17 +138,11 @@ private static double getHoldOutEntropy( double holdoutSumWeightLogWeight = sumWeightLogWeight - getXLogX(bucketWeight) + getXLogX(holdoutBucketWeight); double holdoutEntropy = entryMultiplicity * (n - 1) * - calculateEntropy(width, holdoutSumWeight, holdoutSumWeightLogWeight) / + calculateEntropyFromHistogramAggregates(width, holdoutSumWeight, holdoutSumWeightLogWeight) / n; return holdoutEntropy; } - private static double calculateEntropy(double width, double sumWeight, double sumWeightLogWeight) - { - verify(sumWeight > 0.0); - return Math.max((Math.log(width * sumWeight) - sumWeightLogWeight / sumWeight) / Math.log(2.0), 0.0); - } - @Override public long getEstimatedSize() { @@ -148,7 +150,7 @@ public long getEstimatedSize() } @Override - public int getRequiredBytesForSerialization() + public int getRequiredBytesForSpecificSerialization() { return histogram.getRequiredBytesForSerialization(); } @@ -170,4 +172,10 @@ public DifferentialEntropyStateStrategy clone() { return new FixedHistogramJacknifeStateStrategy(this); } + + @Override + public DifferentialEntropyStateStrategy cloneEmpty() + { + return new FixedHistogramJacknifeStateStrategy(histogram.getBucketCount(), histogram.getMin(), histogram.getMax()); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/FixedHistogramMleStateStrategy.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/FixedHistogramMleStateStrategy.java index 049221ca936f1..6edd2bc7865cc 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/FixedHistogramMleStateStrategy.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/FixedHistogramMleStateStrategy.java @@ -18,6 +18,7 @@ import io.airlift.slice.SliceOutput; import static com.facebook.presto.operator.aggregation.differentialentropy.FixedHistogramStateStrategyUtils.getXLogX; +import static com.google.common.collect.Streams.stream; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; @@ -74,6 +75,14 @@ public void add(double sample, double weight) histogram.add(sample, weight); } + @Override + public double getTotalPopulationWeight() + { + return stream(histogram.iterator()) + .mapToDouble(FixedDoubleHistogram.Bucket::getWeight) + .sum(); + } + @Override public double calculateEntropy() { @@ -99,7 +108,7 @@ public long getEstimatedSize() } @Override - public int getRequiredBytesForSerialization() + public int getRequiredBytesForSpecificSerialization() { return histogram.getRequiredBytesForSerialization(); } @@ -128,4 +137,10 @@ public DifferentialEntropyStateStrategy clone() { return new FixedHistogramMleStateStrategy(this); } + + @Override + public DifferentialEntropyStateStrategy cloneEmpty() + { + return new FixedHistogramMleStateStrategy(histogram.getBucketCount(), histogram.getMin(), histogram.getMax()); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/UnweightedReservoirSampleStateStrategy.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/UnweightedReservoirSampleStateStrategy.java index 8118adfdafe48..b1ab327501783 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/UnweightedReservoirSampleStateStrategy.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/UnweightedReservoirSampleStateStrategy.java @@ -18,7 +18,7 @@ import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; -import static com.facebook.presto.operator.aggregation.differentialentropy.EntropyCalculations.calculateFromSamples; +import static com.facebook.presto.operator.aggregation.differentialentropy.EntropyCalculations.calculateFromSamplesUsingVasicek; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static java.lang.Math.toIntExact; import static java.lang.String.format; @@ -87,20 +87,21 @@ public void mergeWith(DifferentialEntropyStateStrategy other) } @Override - public void add(double value, double weight) + public void add(double value) { - if (weight != 1.0) { - throw new PrestoException( - INVALID_FUNCTION_ARGUMENT, - format("In differential_entropy UDF, weight should be 1.0: %s", weight)); - } reservoir.add(value); } + @Override + public double getTotalPopulationWeight() + { + return reservoir.getTotalPopulationCount(); + } + @Override public double calculateEntropy() { - return calculateFromSamples(reservoir.getSamples()); + return calculateFromSamplesUsingVasicek(reservoir.getSamples()); } @Override @@ -110,7 +111,7 @@ public long getEstimatedSize() } @Override - public int getRequiredBytesForSerialization() + public int getRequiredBytesForSpecificSerialization() { return reservoir.getRequiredBytesForSerialization(); } @@ -131,4 +132,10 @@ public DifferentialEntropyStateStrategy clone() { return new UnweightedReservoirSampleStateStrategy(this); } + + @Override + public DifferentialEntropyStateStrategy cloneEmpty() + { + return new UnweightedReservoirSampleStateStrategy(reservoir.getMaxSamples()); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/WeightedReservoirSampleStateStrategy.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/WeightedReservoirSampleStateStrategy.java index 0ce2d14536f29..e65b81ef845df 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/WeightedReservoirSampleStateStrategy.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/WeightedReservoirSampleStateStrategy.java @@ -18,7 +18,7 @@ import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; -import static com.facebook.presto.operator.aggregation.differentialentropy.EntropyCalculations.calculateFromSamples; +import static com.facebook.presto.operator.aggregation.differentialentropy.EntropyCalculations.calculateFromSamplesUsingVasicek; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.google.common.base.Verify.verify; import static java.lang.String.format; @@ -85,10 +85,16 @@ public void add(double value, double weight) reservoir.add(value, weight); } + @Override + public double getTotalPopulationWeight() + { + return reservoir.getTotalPopulationWeight(); + } + @Override public double calculateEntropy() { - return calculateFromSamples(reservoir.getSamples()); + return calculateFromSamplesUsingVasicek(reservoir.getSamples()); } @Override @@ -98,7 +104,7 @@ public long getEstimatedSize() } @Override - public int getRequiredBytesForSerialization() + public int getRequiredBytesForSpecificSerialization() { return reservoir.getRequiredBytesForSerialization(); } @@ -119,4 +125,10 @@ public DifferentialEntropyStateStrategy clone() { return new WeightedReservoirSampleStateStrategy(this); } + + @Override + public DifferentialEntropyStateStrategy cloneEmpty() + { + return new WeightedReservoirSampleStateStrategy(reservoir.getMaxSamples()); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/fixedhistogram/FixedDoubleBreakdownHistogram.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/fixedhistogram/FixedDoubleBreakdownHistogram.java index 91e70d7cdb221..e1960929badce 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/fixedhistogram/FixedDoubleBreakdownHistogram.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/fixedhistogram/FixedDoubleBreakdownHistogram.java @@ -45,9 +45,9 @@ public class FixedDoubleBreakdownHistogram { private static final int INSTANCE_SIZE = ClassLayout.parseClass(FixedDoubleBreakdownHistogram.class).instanceSize(); - private int bucketCount; - private double min; - private double max; + private final int bucketCount; + private final double min; + private final double max; private int[] indices; private double[] weights; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/fixedhistogram/FixedDoubleHistogram.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/fixedhistogram/FixedDoubleHistogram.java index ce67335423a83..504a07ac55973 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/fixedhistogram/FixedDoubleHistogram.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/fixedhistogram/FixedDoubleHistogram.java @@ -68,9 +68,9 @@ public Bucket(double left, double right, double weight) private final double[] weights; - private int bucketCount; - private double min; - private double max; + private final int bucketCount; + private final double min; + private final double max; public FixedDoubleHistogram(int bucketCount, double min, double max) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/UnweightedDoubleReservoirSample.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/UnweightedDoubleReservoirSample.java index b368cc2d032f1..57d527db14a27 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/UnweightedDoubleReservoirSample.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/UnweightedDoubleReservoirSample.java @@ -115,6 +115,11 @@ public void mergeWith(UnweightedDoubleReservoirSample other) samples = merged; } + public int getTotalPopulationCount() + { + return seenCount; + } + @Override public UnweightedDoubleReservoirSample clone() { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/WeightedDoubleReservoirSample.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/WeightedDoubleReservoirSample.java index fa24428b7f398..102d1bef05155 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/WeightedDoubleReservoirSample.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/WeightedDoubleReservoirSample.java @@ -35,6 +35,7 @@ public class WeightedDoubleReservoirSample private int count; private double[] samples; private double[] weights; + private double totalPopulationWeight; public WeightedDoubleReservoirSample(int maxSamples) { @@ -52,13 +53,15 @@ private WeightedDoubleReservoirSample(WeightedDoubleReservoirSample other) this.count = other.count; this.samples = Arrays.copyOf(other.samples, other.samples.length); this.weights = Arrays.copyOf(other.weights, other.weights.length); + this.totalPopulationWeight = other.totalPopulationWeight; } - private WeightedDoubleReservoirSample(int count, double[] samples, double[] weights) + private WeightedDoubleReservoirSample(int count, double[] samples, double[] weights, double totalPopulationWeight) { this.count = count; this.samples = requireNonNull(samples, "samples is null"); this.weights = requireNonNull(weights, "weights is null"); + this.totalPopulationWeight = totalPopulationWeight; } public long getMaxSamples() @@ -69,6 +72,7 @@ public long getMaxSamples() public void add(double sample, double weight) { checkArgument(weight >= 0, format("Weight %s cannot be negative", weight)); + totalPopulationWeight += weight; double adjustedWeight = Math.pow( ThreadLocalRandom.current().nextDouble(), 1.0 / weight); @@ -96,6 +100,7 @@ private void addWithAdjustedWeight(double sample, double adjustedWeight) public void mergeWith(WeightedDoubleReservoirSample other) { + totalPopulationWeight += other.totalPopulationWeight; for (int i = 0; i < other.count; i++) { addWithAdjustedWeight(other.samples[i], other.weights[i]); } @@ -112,12 +117,6 @@ public double[] getSamples() return Arrays.copyOf(samples, count); } - private void checkArguments() - { - checkArgument(samples.length > 0, "Number of reservoir samples must be strictly positive"); - checkArgument(count <= samples.length, "Size must be at most number of samples"); - } - private void swap(int i, int j) { double tmpElement = samples[i]; @@ -182,7 +181,8 @@ public static WeightedDoubleReservoirSample deserialize(SliceInput input) input.readBytes(Slices.wrappedDoubleArray(samples), count * SizeOf.SIZE_OF_DOUBLE); double[] weights = new double[maxSamples]; input.readBytes(Slices.wrappedDoubleArray(weights), count * SizeOf.SIZE_OF_DOUBLE); - return new WeightedDoubleReservoirSample(count, samples, weights); + double totalPopulationWeight = input.readDouble(); + return new WeightedDoubleReservoirSample(count, samples, weights, totalPopulationWeight); } public void serialize(SliceOutput output) @@ -195,12 +195,14 @@ public void serialize(SliceOutput output) for (int i = 0; i < count; i++) { output.appendDouble(weights[i]); } + output.appendDouble(totalPopulationWeight); } public int getRequiredBytesForSerialization() { return SizeOf.SIZE_OF_INT + // count - SizeOf.SIZE_OF_INT + 2 * SizeOf.SIZE_OF_DOUBLE * Math.min(count, samples.length); // samples, weights + SizeOf.SIZE_OF_INT + 2 * SizeOf.SIZE_OF_DOUBLE * Math.min(count, samples.length) + // samples, weights + SizeOf.SIZE_OF_DOUBLE; // totalPopulationWeight; } public long estimatedInMemorySize() @@ -209,4 +211,9 @@ public long estimatedInMemorySize() SizeOf.sizeOf(samples) + SizeOf.sizeOf(weights); } + + public double getTotalPopulationWeight() + { + return totalPopulationWeight; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/AbstractTestReservoirAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/AbstractTestReservoirAggregation.java index e2f75735375e8..e90aa2a586620 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/AbstractTestReservoirAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/AbstractTestReservoirAggregation.java @@ -15,7 +15,7 @@ import com.facebook.presto.operator.aggregation.AbstractTestAggregationFunction; -import static com.facebook.presto.operator.aggregation.differentialentropy.EntropyCalculations.calculateFromSamples; +import static com.facebook.presto.operator.aggregation.differentialentropy.EntropyCalculations.calculateFromSamplesUsingVasicek; import static org.testng.Assert.assertTrue; abstract class AbstractTestReservoirAggregation @@ -32,12 +32,12 @@ protected String getFunctionName() @Override public Double getExpectedValue(int start, int length) { - assertTrue(length < MAX_SAMPLES); + assertTrue(2 * length < MAX_SAMPLES); double[] samples = new double[2 * length]; for (int i = 0; i < length; i++) { samples[i] = (double) (start + i); samples[i + length] = (double) (start + i); } - return calculateFromSamples(samples); + return calculateFromSamplesUsingVasicek(samples); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/AbstractTestStateStrategy.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/AbstractTestStateStrategy.java index b44fd00b5887c..c690f0d15ac37 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/AbstractTestStateStrategy.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/AbstractTestStateStrategy.java @@ -26,10 +26,14 @@ abstract class AbstractTestStateStrategy protected static final double MAX = 10.0; private final Function strategySupplier; + private final boolean weighted; - protected AbstractTestStateStrategy(Function strategySupplier) + protected AbstractTestStateStrategy( + Function strategySupplier, + boolean weighted) { this.strategySupplier = strategySupplier; + this.weighted = weighted; } @Test @@ -38,7 +42,13 @@ public void testUniformDistribution() DifferentialEntropyStateStrategy strategy = strategySupplier.apply(2000); Random random = new Random(13); for (int i = 0; i < 9_999_999; i++) { - strategy.add(10 * random.nextFloat(), 1.0); + double value = 10 * random.nextFloat(); + if (weighted) { + strategy.add(value, 1.0); + } + else { + strategy.add(value); + } } double expected = Math.log(10) / Math.log(2); assertEquals(strategy.calculateEntropy(), expected, 0.1); @@ -51,7 +61,13 @@ public void testNormalDistribution() Random random = new Random(13); double sigma = 0.5; for (int i = 0; i < 9_999_999; i++) { - strategy.add(5 + sigma * random.nextGaussian(), 1.0); + double value = 5 + sigma * random.nextGaussian(); + if (weighted) { + strategy.add(value, 1.0); + } + else { + strategy.add(value); + } } double expected = 0.5 * Math.log(2 * Math.PI * Math.E * sigma * sigma) / Math.log(2); assertEquals(strategy.calculateEntropy(), expected, 0.02); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestEntropyCalculations.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestEntropyCalculations.java index 23f502587d34e..13b8b5b090cd1 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestEntropyCalculations.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestEntropyCalculations.java @@ -17,7 +17,7 @@ import java.util.Random; -import static com.facebook.presto.operator.aggregation.differentialentropy.EntropyCalculations.calculateFromSamples; +import static com.facebook.presto.operator.aggregation.differentialentropy.EntropyCalculations.calculateFromSamplesUsingVasicek; import static org.testng.Assert.assertEquals; public class TestEntropyCalculations @@ -30,7 +30,7 @@ public void testUniformDistribution() for (int i = 0; i < samples.length; i++) { samples[i] = random.nextDouble(); } - assertEquals(calculateFromSamples(samples), 0, 0.02); + assertEquals(calculateFromSamplesUsingVasicek(samples), 0, 0.02); } @Test @@ -43,6 +43,6 @@ public void testNormalDistribution() samples[i] = 5 + sigma * random.nextGaussian(); } double expected = 0.5 * Math.log(2 * Math.PI * Math.E * sigma * sigma) / Math.log(2); - assertEquals(calculateFromSamples(samples), expected, 0.02); + assertEquals(calculateFromSamplesUsingVasicek(samples), expected, 0.02); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramJacknifeAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramJacknifeAggregation.java index 9ded83efee00b..5f5ddd67e6b9f 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramJacknifeAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramJacknifeAggregation.java @@ -23,13 +23,14 @@ import static com.facebook.presto.block.BlockAssertions.createLongsBlock; import static com.facebook.presto.block.BlockAssertions.createStringsBlock; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.aggregation; +import static com.facebook.presto.operator.aggregation.differentialentropy.DifferentialEntropyStateStrategy.FIXED_HISTOGRAM_JACKNIFE_METHOD_NAME; public class TestFixedHistogramJacknifeAggregation extends AbstractTestFixedHistogramAggregation { public TestFixedHistogramJacknifeAggregation() { - super(DifferentialEntropyAggregation.FIXED_HISTOGRAM_JACKNIFE_METHOD_NAME); + super(FIXED_HISTOGRAM_JACKNIFE_METHOD_NAME); } @Test( diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramJacknifeStateStrategy.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramJacknifeStateStrategy.java index 4fe9d3a45334f..97347d5d92197 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramJacknifeStateStrategy.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramJacknifeStateStrategy.java @@ -18,6 +18,6 @@ public class TestFixedHistogramJacknifeStateStrategy { public TestFixedHistogramJacknifeStateStrategy() { - super(size -> new FixedHistogramJacknifeStateStrategy(size, AbstractTestStateStrategy.MIN, AbstractTestStateStrategy.MAX)); + super(size -> new FixedHistogramJacknifeStateStrategy(size, AbstractTestStateStrategy.MIN, AbstractTestStateStrategy.MAX), true); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramMleAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramMleAggregation.java index 0f2332dea7b8a..e822efd031abe 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramMleAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramMleAggregation.java @@ -23,13 +23,14 @@ import static com.facebook.presto.block.BlockAssertions.createLongsBlock; import static com.facebook.presto.block.BlockAssertions.createStringsBlock; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.aggregation; +import static com.facebook.presto.operator.aggregation.differentialentropy.DifferentialEntropyStateStrategy.FIXED_HISTOGRAM_MLE_METHOD_NAME; public class TestFixedHistogramMleAggregation extends AbstractTestFixedHistogramAggregation { public TestFixedHistogramMleAggregation() { - super(DifferentialEntropyAggregation.FIXED_HISTOGRAM_MLE_METHOD_NAME); + super(FIXED_HISTOGRAM_MLE_METHOD_NAME); } @Test( diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramMleStateStrategy.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramMleStateStrategy.java index abb538587298a..3075a24c1eff9 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramMleStateStrategy.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestFixedHistogramMleStateStrategy.java @@ -18,6 +18,6 @@ public class TestFixedHistogramMleStateStrategy { public TestFixedHistogramMleStateStrategy() { - super(bucketCount -> new FixedHistogramMleStateStrategy(bucketCount, MIN, MAX)); + super(bucketCount -> new FixedHistogramMleStateStrategy(bucketCount, MIN, MAX), true); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestIllegalMethodAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestIllegalMethodAggregation.java index 8cbc2367244dd..ba071b0686fe5 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestIllegalMethodAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestIllegalMethodAggregation.java @@ -59,7 +59,6 @@ public void testNullMethod() "differential_entropy", fromTypes(BIGINT, DOUBLE, DOUBLE, VARCHAR, DOUBLE, DOUBLE))); createStringsBlock((String) null); - System.out.println("foo"); aggregation( function, createLongsBlock(200), diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestUnweightedReservoirSampleStateStrategy.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestUnweightedReservoirSampleStateStrategy.java index 8f75ba282e4d4..bc0b0c7a4e634 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestUnweightedReservoirSampleStateStrategy.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestUnweightedReservoirSampleStateStrategy.java @@ -18,6 +18,6 @@ public class TestUnweightedReservoirSampleStateStrategy { public TestUnweightedReservoirSampleStateStrategy() { - super(size -> new UnweightedReservoirSampleStateStrategy(size)); + super(size -> new UnweightedReservoirSampleStateStrategy(size), false); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestWeightedReservoirSampleStateStrategy.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestWeightedReservoirSampleStateStrategy.java index 61126b12513fe..e8f76b443bc00 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestWeightedReservoirSampleStateStrategy.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestWeightedReservoirSampleStateStrategy.java @@ -18,6 +18,6 @@ public class TestWeightedReservoirSampleStateStrategy { public TestWeightedReservoirSampleStateStrategy() { - super(size -> new WeightedReservoirSampleStateStrategy(size)); + super(size -> new WeightedReservoirSampleStateStrategy(size), true); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestUnweightedDoubleReservoirSample.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestUnweightedDoubleReservoirSample.java index 645556564faa3..311c85053fad2 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestUnweightedDoubleReservoirSample.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestUnweightedDoubleReservoirSample.java @@ -34,9 +34,10 @@ public void testIllegalMaxSamples() @Test public void testGetMaxSamples() { - UnweightedDoubleReservoirSample sample = new UnweightedDoubleReservoirSample(200); + UnweightedDoubleReservoirSample reservoir = new UnweightedDoubleReservoirSample(200); - assertEquals(sample.getMaxSamples(), 200); + assertEquals(reservoir.getMaxSamples(), 200); + assertEquals(reservoir.getTotalPopulationCount(), 0); } @Test @@ -49,6 +50,7 @@ public void testFew() reservoir.add(3.0); assertEquals(Arrays.stream(reservoir.getSamples()).sorted().toArray(), new double[] {1.0, 2.0, 3.0}); + assertEquals(reservoir.getTotalPopulationCount(), 3); } @Test @@ -58,6 +60,7 @@ public void testMany() long streamLength = 1_000_000; for (int i = 0; i < streamLength; ++i) { + assertEquals(reservoir.getTotalPopulationCount(), i); reservoir.add(i); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestWeightedDoubleReservoirSample.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestWeightedDoubleReservoirSample.java index 3fe2df8d1e7a4..975063e1c1c31 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestWeightedDoubleReservoirSample.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestWeightedDoubleReservoirSample.java @@ -34,9 +34,10 @@ public void testIllegalMaxSamples() @Test public void testGetters() { - WeightedDoubleReservoirSample sample = new WeightedDoubleReservoirSample(200); + WeightedDoubleReservoirSample reservoir = new WeightedDoubleReservoirSample(200); - assertEquals(sample.getMaxSamples(), 200); + assertEquals(reservoir.getMaxSamples(), 200); + assertEquals(reservoir.getTotalPopulationWeight(), 0.0); } @Test @@ -49,6 +50,7 @@ public void testFew() reservoir.add(3.0, 0.5); assertEquals(Arrays.stream(reservoir.getSamples()).sorted().toArray(), new double[] {1.0, 2.0, 3.0}); + assertEquals(reservoir.getTotalPopulationWeight(), 2.5); } @Test @@ -58,6 +60,7 @@ public void testMany() long streamLength = 1_000_000; for (int i = 0; i < streamLength; ++i) { + assertEquals(reservoir.getTotalPopulationWeight(), i, 0.0001); reservoir.add(i, 1.0); } @@ -81,8 +84,10 @@ public void testManyWeighted() WeightedDoubleReservoirSample reservoir = new WeightedDoubleReservoirSample(200); long streamLength = 1_000_000; + double epsilon = 0.00000001; for (int i = 0; i < streamLength; ++i) { - reservoir.add(3, 0.00000001); + assertEquals(reservoir.getTotalPopulationWeight(), epsilon * i, epsilon / 100); + reservoir.add(3, epsilon); } for (int i = 0; i < streamLength; ++i) { reservoir.add(i, 9999999999.0);