From 6f47a4ce79e9ee30376f619667ceee39bd792a21 Mon Sep 17 00:00:00 2001 From: Zac Blanco Date: Tue, 21 Nov 2023 17:13:45 +0800 Subject: [PATCH] Add reservoir_sample aggregation function This commit introduces a new `reservoir_sample` aggregate function which, as opposed to the existing TABLESAMPLE operator lets users pick a fixed sample size. The fixed sample sizes lets users create samples of a known total size while guaranteeing every record has an equal probability of being chosen. Co-authored-by: xiz675 <32505316+xiz675@users.noreply.github.com> --- .../src/main/sphinx/functions/aggregate.rst | 113 +++++++ presto-docs/src/main/sphinx/sql/select.rst | 1 + ...uiltInTypeAndFunctionNamespaceManager.java | 4 +- .../UnweightedDoubleReservoirSample.java | 2 +- ...nweightedReservoirSampleStateStrategy.java | 1 - .../WeightedDoubleReservoirSample.java | 2 +- .../WeightedReservoirSampleStateStrategy.java | 1 - .../GroupedReservoirSampleState.java | 60 ++++ .../reservoirsample/ReservoirSample.java | 290 ++++++++++++++++++ .../ReservoirSampleFunction.java | 120 ++++++++ .../reservoirsample/ReservoirSampleState.java | 26 ++ .../ReservoirSampleStateFactory.java | 53 ++++ .../ReservoirSampleStateSerializer.java | 86 ++++++ .../SingleReservoirSampleState.java | 45 +++ .../aggregation/AggregationTestUtils.java | 8 + .../TestUnweightedDoubleReservoirSample.java | 2 +- .../TestWeightedDoubleReservoirSample.java | 2 +- .../TestReservoirSampleAggregation.java | 275 +++++++++++++++++ 18 files changed, 1084 insertions(+), 7 deletions(-) rename presto-main/src/main/java/com/facebook/presto/operator/aggregation/{reservoirsample => differentialentropy}/UnweightedDoubleReservoirSample.java (98%) rename presto-main/src/main/java/com/facebook/presto/operator/aggregation/{reservoirsample => differentialentropy}/WeightedDoubleReservoirSample.java (98%) create mode 100644 presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/GroupedReservoirSampleState.java create mode 100644 presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSample.java create mode 100644 presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleFunction.java create mode 100644 presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleState.java create mode 100644 presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleStateFactory.java create mode 100644 presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleStateSerializer.java create mode 100644 presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/SingleReservoirSampleState.java rename presto-main/src/test/java/com/facebook/presto/operator/aggregation/{reservoirsample => differentialentropy}/TestUnweightedDoubleReservoirSample.java (97%) rename presto-main/src/test/java/com/facebook/presto/operator/aggregation/{reservoirsample => differentialentropy}/TestWeightedDoubleReservoirSample.java (98%) create mode 100644 presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestReservoirSampleAggregation.java diff --git a/presto-docs/src/main/sphinx/functions/aggregate.rst b/presto-docs/src/main/sphinx/functions/aggregate.rst index 58ca477740b14..b04db2091be5e 100644 --- a/presto-docs/src/main/sphinx/functions/aggregate.rst +++ b/presto-docs/src/main/sphinx/functions/aggregate.rst @@ -963,6 +963,117 @@ where :math:`f(x)` is the partial density function of :math:`x`. The function uses the stream summary data structure proposed in the paper `Efficient computation of frequent and top-k elements in data streams `_ by A.Metwally, D.Agrawal and A.Abbadi. +Reservoir Sample Functions +------------------------------- + +Reservoir sample functions use a fixed sample size, as opposed to +:ref:`TABLESAMPLE `. Fixed sample sizes always result in a +fixed total size while still guaranteeing that each record in dataset has an +equal probability of being chosen. See [Vitter1985]_. + +.. function:: reservoir_sample(initial_sample: array(T), initial_processed_count: bigint, values_to_sample: T, desired_sample_size: int) -> row(processed_count: bigint, sample: array(T)) + + Computes a new reservoir sample given: + + - ``initial_sample``: an initial sample array, or ``NULL`` if creating a new + sample. + - ``initial_processed_count``: the number of records processed to generate + the initial sample array. This should be 0 or ``NULL`` if + ``initital_sample`` is ``NULL``. + - ``values_to_sample``: the column to sample from. + - ``desired_sample_size``: the size of reservoir sample. + + The function outputs a single row type with two columns: + + #. Processed count: The total number of rows the function sampled + from. It includes the total from the ``initial_processed_count``, + if provided. + + #. Reservoir sample: An array with length equivalent to the minimum of + ``desired_sample_size`` and the number of values in the + ``values_to_sample`` argument. + + + .. code-block:: sql + + WITH result as ( + SELECT + reservoir_sample(NULL, 0, col, 5) as reservoir + FROM ( + VALUES + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0 + ) as t(col) + ) + SELECT + reservoir.processed_count, reservoir.sample + FROM result; + + .. code-block:: none + + processed_count | sample + -----------------+----------------- + 10 | [1, 2, 8, 4, 5] + + To merge older samples with new data, supply valid arguments to the + ``initial_sample`` argument and ``initial_processed_count`` arguments. + + .. code-block:: sql + + WITH initial_sample as ( + SELECT + reservoir_sample(NULL, 0, col, 3) as reservoir + FROM ( + VALUES + 0, 1, 2, 3, 4 + ) as t(col) + ), + new_sample as ( + SELECT + reservoir_sample( + (SELECT reservoir.sample FROM initial_sample), + (SELECT reservoir.processed_count FROM initial_sample), + col, + 3 + ) as result + FROM ( + VALUES + 5, 6, 7, 8, 9 + ) as t(col) + ) + SELECT + result.processed_count, result.sample + FROM new_sample; + + .. code-block:: none + + processed_count | sample + -----------------+----------- + 10 | [8, 3, 2] + + To sample an entire row of a table, use a ``ROW`` type input with + each subfield corresponding to the columns of the source table. + + .. code-block:: sql + + WITH result as ( + SELECT + reservoir_sample(NULL, 0, CAST(row(idx, val) AS row(idx int, val varchar)), 2) as reservoir + FROM ( + VALUES + (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'), (5, 'e') + ) as t(idx, val) + ) + SELECT + reservoir.processed_count, reservoir.sample + FROM result; + + .. code-block:: none + + processed_count | sample + -----------------+---------------------------------- + 5 | [{idx=1, val=a}, {idx=5, val=e}] + + --------------------------- @@ -978,3 +1089,5 @@ where :math:`f(x)` is the partial density function of :math:`x`. .. [Efraimidis2006] Efraimidis, Pavlos S.; Spirakis, Paul G. (2006-03-16). "Weighted random sampling with a reservoir". Information Processing Letters. 97 (5): 181–185. + +.. [Vitter1985] Vitter, Jeffrey S. "Random sampling with a reservoir." ACM Transactions on Mathematical Software (TOMS) 11.1 (1985): 37-57. diff --git a/presto-docs/src/main/sphinx/sql/select.rst b/presto-docs/src/main/sphinx/sql/select.rst index 9baab71c6fc6a..11c5c2202129f 100644 --- a/presto-docs/src/main/sphinx/sql/select.rst +++ b/presto-docs/src/main/sphinx/sql/select.rst @@ -655,6 +655,7 @@ after the ``OFFSET`` clause:: 4 (2 rows) +.. _sql-tablesample: TABLESAMPLE ----------- diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index d962af1672f85..18b4f980de5d9 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -103,6 +103,7 @@ import com.facebook.presto.operator.aggregation.noisyaggregation.NoisyApproximateSetSfmAggregationDefaultPrecision; import com.facebook.presto.operator.aggregation.noisyaggregation.NoisyCountIfGaussianAggregation; import com.facebook.presto.operator.aggregation.noisyaggregation.SfmSketchMergeAggregation; +import com.facebook.presto.operator.aggregation.reservoirsample.ReservoirSampleFunction; import com.facebook.presto.operator.scalar.ArrayAllMatchFunction; import com.facebook.presto.operator.scalar.ArrayAnyMatchFunction; import com.facebook.presto.operator.scalar.ArrayCardinalityFunction; @@ -970,7 +971,8 @@ private List getBuildInFunctions(FeaturesConfig featuresC .function(DISTINCT_TYPE_DISTINCT_FROM_OPERATOR) .functions(DISTINCT_TYPE_HASH_CODE_OPERATOR, DISTINCT_TYPE_XX_HASH_64_OPERATOR) .function(DISTINCT_TYPE_INDETERMINATE_OPERATOR) - .codegenScalars(MapFilterFunction.class); + .codegenScalars(MapFilterFunction.class) + .aggregate(ReservoirSampleFunction.class); switch (featuresConfig.getRegexLibrary()) { case JONI: diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/UnweightedDoubleReservoirSample.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/UnweightedDoubleReservoirSample.java similarity index 98% rename from presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/UnweightedDoubleReservoirSample.java rename to presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/UnweightedDoubleReservoirSample.java index 57d527db14a27..d9c7ea87dbc9b 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/UnweightedDoubleReservoirSample.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/UnweightedDoubleReservoirSample.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.operator.aggregation.reservoirsample; +package com.facebook.presto.operator.aggregation.differentialentropy; import io.airlift.slice.SizeOf; import io.airlift.slice.SliceInput; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/UnweightedReservoirSampleStateStrategy.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/UnweightedReservoirSampleStateStrategy.java index b1ab327501783..908612b3904f9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/UnweightedReservoirSampleStateStrategy.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/UnweightedReservoirSampleStateStrategy.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.operator.aggregation.differentialentropy; -import com.facebook.presto.operator.aggregation.reservoirsample.UnweightedDoubleReservoirSample; import com.facebook.presto.spi.PrestoException; import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/WeightedDoubleReservoirSample.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/WeightedDoubleReservoirSample.java similarity index 98% rename from presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/WeightedDoubleReservoirSample.java rename to presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/WeightedDoubleReservoirSample.java index 102d1bef05155..8be48379d0f71 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/WeightedDoubleReservoirSample.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/WeightedDoubleReservoirSample.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.operator.aggregation.reservoirsample; +package com.facebook.presto.operator.aggregation.differentialentropy; import io.airlift.slice.SizeOf; import io.airlift.slice.SliceInput; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/WeightedReservoirSampleStateStrategy.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/WeightedReservoirSampleStateStrategy.java index e65b81ef845df..d165188c39b22 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/WeightedReservoirSampleStateStrategy.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/differentialentropy/WeightedReservoirSampleStateStrategy.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.operator.aggregation.differentialentropy; -import com.facebook.presto.operator.aggregation.reservoirsample.WeightedDoubleReservoirSample; import com.facebook.presto.spi.PrestoException; import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/GroupedReservoirSampleState.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/GroupedReservoirSampleState.java new file mode 100644 index 0000000000000..2dec0c09d0782 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/GroupedReservoirSampleState.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.reservoirsample; + +import com.facebook.presto.common.array.ObjectBigArray; +import com.facebook.presto.operator.aggregation.state.AbstractGroupedAccumulatorState; +import org.openjdk.jol.info.ClassLayout; + +import static java.util.Objects.requireNonNull; + +public class GroupedReservoirSampleState + extends AbstractGroupedAccumulatorState + implements ReservoirSampleState +{ + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedReservoirSampleState.class).instanceSize(); + private final ObjectBigArray samples = new ObjectBigArray<>(); + private long size; + + @Override + public long getEstimatedSize() + { + return INSTANCE_SIZE + size + samples.sizeOf(); + } + + @Override + public void ensureCapacity(long size) + { + samples.ensureCapacity(size); + } + + @Override + public ReservoirSample get() + { + return samples.get(getGroupId()); + } + + @Override + public void set(ReservoirSample value) + { + requireNonNull(value, "value is null"); + ReservoirSample previous = get(); + if (previous != null) { + size -= previous.estimatedInMemorySize(); + } + + samples.set(getGroupId(), value); + size += value.estimatedInMemorySize(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSample.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSample.java new file mode 100644 index 0000000000000..8728f01f6b084 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSample.java @@ -0,0 +1,290 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.reservoirsample; + +import com.facebook.presto.common.block.ArrayBlock; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.Type; +import io.airlift.slice.SizeOf; +import org.openjdk.jol.info.ClassLayout; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.Math.max; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class ReservoirSample +{ + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleReservoirSampleState.class).instanceSize(); + private final Type type; + + public Type getArrayType() + { + return arrayType; + } + + private final Type arrayType; + /** + * Represents the list of sampled values. + *
+ * We use an {@link ArrayList} instead of {@link Block} because the + * algorithm that generates reservoir samples requires shuffling of elements + * in the reservoir. + *
+ * The {@link Block} interface doesn't have any method for setting values at + * arbitrary positions, so we resort to internally representing the sample + * as a list and then combining the samples into a single block later. + */ + private ArrayList samples; + private int maxSampleSize = -1; + private long processedCount; + + private Block initialSample; + + public Block getInitialSample() + { + return initialSample; + } + + public long getInitialProcessedCount() + { + return initialProcessedCount; + } + + private long initialProcessedCount = -1; + + public ReservoirSample(Type type) + { + this.type = requireNonNull(type, "type is null"); + this.arrayType = new ArrayType(type); + this.samples = new ArrayList<>(); + } + + protected ReservoirSample(Type type, long processedCount, int maxSampleSize, Block samples, Block initialSample, long initialSeenCount) + { + this.type = requireNonNull(type, "type is null"); + this.arrayType = new ArrayType(type); + this.processedCount = processedCount; + this.samples = blockToList(samples); + this.maxSampleSize = maxSampleSize; + initializeInitialSample(initialSample, initialSeenCount); + } + + private static ArrayList blockToList(Block inputBlock) + { + // sometimes single values such as bigint/double are serialized as + // LongArrayBlock which don't implement the Block::getBlock function. + // ArrayBlock::getSingleValueBlock returns another ArrayBlock of size 1, whereas + // we need to extract the internal block rather than have an array + Function extractor = inputBlock instanceof ArrayBlock ? inputBlock::getBlock : inputBlock::getSingleValueBlock; + return IntStream.range(0, inputBlock.getPositionCount()) + .mapToObj(extractor::apply) + .collect(Collectors.toCollection(ArrayList::new)); + } + + private static ArrayList mergeBlockSamples(ArrayList samples1, ArrayList samples2, long seenCount1, long seenCount2) + { + int nextIndex = 0; + int otherNextIndex = 0; + ArrayList merged = new ArrayList<>(samples1.size()); + for (int i = 0; i < samples1.size(); i++) { + if (ThreadLocalRandom.current().nextLong(0, seenCount1 + seenCount2) < seenCount1) { + merged.add(samples1.get(nextIndex++)); + } + else { + merged.add(samples2.get(otherNextIndex++)); + } + } + return merged; + } + + public void tryInitialize(int n) + { + if (sampleNotInitialized()) { + samples = new ArrayList<>(max(n, 0)); + maxSampleSize = n; + } + } + + public void initializeInitialSample(@Nullable Block initialSample, long initialProcessedCount) + { + if (this.initialProcessedCount < 0) { + if (initialSample != null && initialSample.getPositionCount() > 0) { + checkArgument(initialProcessedCount >= initialSample.getPositionCount(), + "initialProcessedCount must be greater than or equal " + + "to the number of positions in the initial sample"); + } + this.initialSample = initialSample; + this.initialProcessedCount = initialProcessedCount; + } + } + + public void mergeWith(@Nullable ReservoirSample other) + { + if (other == null) { + return; + } + merge(other); + initializeInitialSample(other.initialSample, other.initialProcessedCount); + } + + private boolean sampleNotInitialized() + { + return maxSampleSize < 0 || samples == null; + } + + public int getSampleSize() + { + if (sampleNotInitialized()) { + return 0; + } + return samples.size(); + } + + public int getMaxSampleSize() + { + return maxSampleSize; + } + + /** + * Potentially add a value from a block at a given position into the sample. + * + * @param block the block containing the potential sample + * @param position the position in the block to potentially insert + */ + public void add(Block block, int position) + { + if (sampleNotInitialized()) { + throw new IllegalArgumentException("reservoir sample not properly initialized"); + } + processedCount++; + int sampleSize = getMaxSampleSize(); + if (processedCount <= sampleSize) { + BlockBuilder sampleBlock = type.createBlockBuilder(null, 1); + type.appendTo(block, position, sampleBlock); + samples.add(sampleBlock.build()); + } + else { + long index = ThreadLocalRandom.current().nextLong(0, processedCount); + if (index < samples.size()) { + BlockBuilder sampleBlock = type.createBlockBuilder(null, 1); + type.appendTo(block, position, sampleBlock); + samples.set((int) index, sampleBlock.build()); + } + } + } + + private void addSingleBlock(Block block) + { + processedCount++; + int sampleSize = getMaxSampleSize(); + if (processedCount <= sampleSize) { + samples.add(block); + } + else { + long index = ThreadLocalRandom.current().nextLong(0L, processedCount); + if (index < samples.size()) { + samples.set((int) index, block); + } + } + } + + public void merge(ReservoirSample other) + { + if (sampleNotInitialized()) { + tryInitialize(other.getMaxSampleSize()); + } + if (other.sampleNotInitialized()) { + return; + } + checkArgument( + getMaxSampleSize() == other.getMaxSampleSize(), + format("maximum number of samples %s must be equal to that of other %s", getMaxSampleSize(), other.getMaxSampleSize())); + if (other.processedCount < getMaxSampleSize()) { + for (int i = 0; i < other.samples.size(); i++) { + addSingleBlock(other.samples.get(i)); + } + return; + } + if (processedCount < getMaxSampleSize()) { + for (int i = 0; i < processedCount; i++) { + other.addSingleBlock(samples.get(i)); + } + processedCount = other.processedCount; + samples = other.samples; + return; + } + Collections.shuffle(samples); + Collections.shuffle(other.samples); + samples = mergeBlockSamples(samples, other.samples, processedCount, other.processedCount); + processedCount += other.processedCount; + } + + public Type getType() + { + return type; + } + + public long getProcessedCount() + { + return processedCount; + } + + public long estimatedInMemorySize() + { + return INSTANCE_SIZE + + (initialSample != null ? initialSample.getSizeInBytes() : 0) + + SizeOf.sizeOfObjectArray(samples.size()); + } + + public void serialize(BlockBuilder out) + { + BlockBuilder sampleBlock = getSampleBlockBuilder(); + if (initialSample == null) { + out.appendNull(); + } + else { + out.appendStructure(initialSample); + } + BIGINT.writeLong(out, initialProcessedCount); + BIGINT.writeLong(out, processedCount); + INTEGER.writeLong(out, maxSampleSize); + arrayType.appendTo(sampleBlock.build(), 0, out); + } + + BlockBuilder getSampleBlockBuilder() + { + int sampleSize = getSampleSize(); + BlockBuilder sampleBlock = arrayType.createBlockBuilder(null, sampleSize); + BlockBuilder sampleEntryBuilder = sampleBlock.beginBlockEntry(); + for (int i = 0; i < sampleSize; i++) { + type.appendTo(samples.get(i), 0, sampleEntryBuilder); + } + sampleBlock.closeEntry(); + return sampleBlock; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleFunction.java new file mode 100644 index 0000000000000..382f18a90e995 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleFunction.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.reservoirsample; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.RunLengthEncodedBlock; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.operator.aggregation.NullablePosition; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; + +import java.util.Optional; + +import static com.facebook.presto.common.type.StandardTypes.BIGINT; +import static com.facebook.presto.common.type.StandardTypes.INTEGER; +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.Math.max; + +@AggregationFunction(value = "reservoir_sample", isCalledOnNullInput = true) +@Description("Generates a fixed-size bernoulli sample from the input column. Will merge an existing sample into the newly-generated sample.") +public class ReservoirSampleFunction +{ + public static final String NAME = "reservoir_sample"; + + private ReservoirSampleFunction() + { + } + + @InputFunction + @TypeParameter("T") + public static void input( + @TypeParameter("T") Type type, + @AggregationState ReservoirSampleState state, + @BlockPosition @SqlType("array(T)") @NullablePosition Block initialState, + @BlockIndex int initialStatePos, + @SqlType(BIGINT) long initialProcessedCount, + @BlockPosition @SqlType("T") @NullablePosition Block value, + @BlockIndex int position, + @SqlType(INTEGER) long desiredSampleSize) + { + checkArgument(desiredSampleSize > 0, "desired sample size must be > 0"); + if (initialProcessedCount <= 0) { + // initial state block must be null or empty to prevent confusing situation where the + // initial sample is not used + checkArgument(initialState.isNull(initialStatePos) || initialState.getBlock(initialStatePos).getPositionCount() == 0, "initial state array must be null or empty when initial processed count is <= 0"); + } + if (state.get() == null) { + state.set(new ReservoirSample(type)); + } + ReservoirSample sample = state.get(); + sample.tryInitialize((int) desiredSampleSize); + + Block initialStateBlock = null; + if (initialProcessedCount > 0) { + initialStateBlock = initialState.getBlock(initialStatePos); + } + sample.initializeInitialSample(initialStateBlock, initialProcessedCount); + sample.add(value, position); + } + + @CombineFunction + public static void combine( + @AggregationState ReservoirSampleState state, + @AggregationState ReservoirSampleState otherState) + { + if (state.get() == null) { + state.set(otherState.get()); + return; + } + state.get().mergeWith(otherState.get()); + } + + @OutputFunction("row(processed_count bigint, sample array(T))") + public static void output( + @TypeParameter("T") Type elementType, + @AggregationState ReservoirSampleState state, + BlockBuilder out) + { + ReservoirSample reservoirSample = state.get(); + final Block initialSampleBlock = Optional.ofNullable(reservoirSample.getInitialSample()) + .orElseGet(() -> RunLengthEncodedBlock.create(elementType, null, 0)); + long initialProcessedCount = reservoirSample.getInitialProcessedCount(); + // merge the final state with the initial state given + checkArgument(!(initialProcessedCount != -1 && + initialProcessedCount != initialSampleBlock.getPositionCount()) || + reservoirSample.getMaxSampleSize() == initialSampleBlock.getPositionCount(), + "when a positive initial_processed_count is provided the size of " + + "the initial sample must be equal to desired_sample_size parameter"); + ReservoirSample finalSample = new ReservoirSample(elementType, max(initialProcessedCount, 0), reservoirSample.getMaxSampleSize(), initialSampleBlock, null, 0); + finalSample.merge(reservoirSample); + + long count = finalSample.getProcessedCount(); + BlockBuilder entryBuilder = out.beginBlockEntry(); + BigintType.BIGINT.writeLong(entryBuilder, count); + BlockBuilder sampleBlock = finalSample.getSampleBlockBuilder(); + reservoirSample.getArrayType().appendTo(sampleBlock.build(), 0, entryBuilder); + out.closeEntry(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleState.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleState.java new file mode 100644 index 0000000000000..d7be04f563686 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleState.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.reservoirsample; + +import com.facebook.presto.spi.function.AccumulatorState; +import com.facebook.presto.spi.function.AccumulatorStateMetadata; + +@AccumulatorStateMetadata(stateSerializerClass = ReservoirSampleStateSerializer.class, stateFactoryClass = ReservoirSampleStateFactory.class) +public interface ReservoirSampleState + extends AccumulatorState +{ + ReservoirSample get(); + + void set(ReservoirSample value); +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleStateFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleStateFactory.java new file mode 100644 index 0000000000000..156813a4e98c7 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleStateFactory.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.reservoirsample; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.function.AccumulatorStateFactory; +import com.facebook.presto.spi.function.TypeParameter; + +public class ReservoirSampleStateFactory + implements AccumulatorStateFactory +{ + private final Type type; + + public ReservoirSampleStateFactory(@TypeParameter("T") Type type) + { + this.type = type; + } + + @Override + public ReservoirSampleState createSingleState() + { + return new SingleReservoirSampleState(type); + } + + @Override + public Class getSingleStateClass() + { + return SingleReservoirSampleState.class; + } + + @Override + public ReservoirSampleState createGroupedState() + { + return new GroupedReservoirSampleState(); + } + + @Override + public Class getGroupedStateClass() + { + return GroupedReservoirSampleState.class; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleStateSerializer.java new file mode 100644 index 0000000000000..f76f713f1acfc --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSampleStateSerializer.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.reservoirsample; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.RowType.Field; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.function.TypeParameter; + +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.IntegerType.INTEGER; + +public class ReservoirSampleStateSerializer + implements AccumulatorStateSerializer +{ + private final Type elementType; + private final Type arrayType; + + public ReservoirSampleStateSerializer(@TypeParameter("T") Type elementType) + { + this.elementType = elementType; + this.arrayType = new ArrayType(elementType); + } + + @Override + public Type getSerializedType() + { + Field initialSample = new Field(Optional.of("initialSample"), arrayType); + Field initialSeenCount = new Field(Optional.of("initialSeenCount"), BIGINT); + Field seenCount = new Field(Optional.of("seenCount"), BIGINT); + Field maxSampleSize = new Field(Optional.of("maxSampleSize"), INTEGER); + Field sample = new Field(Optional.of("sample"), arrayType); + List fields = Arrays.asList(initialSample, initialSeenCount, seenCount, maxSampleSize, sample); + return RowType.from(fields); + } + + @Override + public void serialize(ReservoirSampleState state, BlockBuilder out) + { + if (state.get() == null) { + out.appendNull(); + } + else { + BlockBuilder entryBuilder = out.beginBlockEntry(); + state.get().serialize(entryBuilder); + out.closeEntry(); + } + } + + @Override + public void deserialize(Block block, int index, ReservoirSampleState state) + { + if (block.isNull(index)) { + state.set(null); + return; + } + Type rowTypes = getSerializedType(); + Block stateBlock = (Block) rowTypes.getObject(block, index); + Block initialSample = (Block) arrayType.getObject(stateBlock, 0); + long initialSeenCount = stateBlock.getLong(1); + long seenCount = stateBlock.getLong(2); + int maxSampleSize = stateBlock.getInt(3); + Block samplesBlock = (Block) arrayType.getObject(stateBlock, 4); + ReservoirSample reservoirSample = new ReservoirSample(elementType, seenCount, maxSampleSize, samplesBlock, initialSample, initialSeenCount); + state.set(reservoirSample); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/SingleReservoirSampleState.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/SingleReservoirSampleState.java new file mode 100644 index 0000000000000..cfbdd92e14ba1 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/SingleReservoirSampleState.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.reservoirsample; + +import com.facebook.presto.common.type.Type; + +public class SingleReservoirSampleState + implements ReservoirSampleState +{ + private ReservoirSample sample; + + public SingleReservoirSampleState(Type type) + { + sample = new ReservoirSample(type); + } + + @Override + public long getEstimatedSize() + { + return sample.estimatedInMemorySize(); + } + + @Override + public ReservoirSample get() + { + return sample; + } + + @Override + public void set(ReservoirSample value) + { + sample = value; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AggregationTestUtils.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AggregationTestUtils.java index 80986976d1421..963178a1d1592 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AggregationTestUtils.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AggregationTestUtils.java @@ -253,6 +253,14 @@ public static Object aggregation(JavaAggregationFunctionImplementation function, return aggregation; } + /** + * Gets the aggregation result with asserting any intermediate states or reversed argument results + */ + public static Object executeAggregation(JavaAggregationFunctionImplementation function, Block... blocks) + { + return aggregation(function, createArgs(function), Optional.empty(), new Page(blocks)); + } + private static Object aggregation(JavaAggregationFunctionImplementation function, int[] args, Optional maskChannel, Page... pages) { Accumulator aggregation = generateAccumulatorFactory(function, Ints.asList(args), maskChannel).createAccumulator(UpdateMemory.NOOP); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestUnweightedDoubleReservoirSample.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestUnweightedDoubleReservoirSample.java similarity index 97% rename from presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestUnweightedDoubleReservoirSample.java rename to presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestUnweightedDoubleReservoirSample.java index 311c85053fad2..91a46f161405d 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestUnweightedDoubleReservoirSample.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestUnweightedDoubleReservoirSample.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.operator.aggregation.reservoirsample; +package com.facebook.presto.operator.aggregation.differentialentropy; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestWeightedDoubleReservoirSample.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestWeightedDoubleReservoirSample.java similarity index 98% rename from presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestWeightedDoubleReservoirSample.java rename to presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestWeightedDoubleReservoirSample.java index 975063e1c1c31..c3e98c304240d 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestWeightedDoubleReservoirSample.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/differentialentropy/TestWeightedDoubleReservoirSample.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.operator.aggregation.reservoirsample; +package com.facebook.presto.operator.aggregation.differentialentropy; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestReservoirSampleAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestReservoirSampleAggregation.java new file mode 100644 index 0000000000000..370a8cdfe36f6 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/reservoirsample/TestReservoirSampleAggregation.java @@ -0,0 +1,275 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.reservoirsample; + +import com.facebook.presto.Session; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Parameters; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.UnknownType.UNKNOWN; +import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; +import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; +import static com.facebook.presto.operator.aggregation.AggregationTestUtils.executeAggregation; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestReservoirSampleAggregation +{ + protected FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); + protected Session session = testSessionBuilder().build(); + + @Test + public void testNoInitialSample() + { + assertAggregation( + getDoubleFunction(), + // seen count, and sample + ImmutableList.of(5L, ImmutableList.of(1.0, 1.0)), + // arguments + copyBlock(BIGINT, nullBlock(), 5), + copyBlock(BIGINT, bigintBlock(0), 5), + doubleBlock(1, 1, 1, 1, 1), + copyBlock(INTEGER, intBlock(2), 5)); + } + + @Test + public void testLarge() + { + int sampleSize = 5000; + int inputSize = 15_000; + assertAggregation( + getDoubleFunction(), + // seen count, and sample + ImmutableList.of((long) inputSize, IntStream.range(0, sampleSize).mapToObj(x -> 1.0).collect(Collectors.toList())), + // arguments + copyBlock(BIGINT, nullBlock(), inputSize), + copyBlock(BIGINT, bigintBlock(0), inputSize), + doubleBlock(IntStream.range(0, inputSize).mapToDouble(x -> 1.0).toArray()), + copyBlock(INTEGER, intBlock(sampleSize), inputSize)); + } + + @DataProvider(name = "invalidSampleSize") + public Object[][] invalidSampleParameters() + { + return new Object[][] {{0}, {-1}}; + } + + /** + * Throws exception when desired sample size is <= 0 + */ + @Test(dataProvider = "invalidSampleSize", expectedExceptions = IllegalArgumentException.class) + @Parameters("sampleSize") + public void testInvalidSampleSize(int sampleSize) + { + assertAggregation( + getDoubleFunction(), + // seen count, and sample + ImmutableList.of(-1L, ImmutableList.of(1.0, 1.0)), + // arguments + copyBlock(BIGINT, nullBlock(), 5), + copyBlock(BIGINT, bigintBlock(0), 5), + doubleBlock(1, 1, 1, 1, 1), + copyBlock(INTEGER, intBlock(sampleSize), 5)); + } + + @Test + public void testInitialSampleSameSize() + { + assertAggregation( + getDoubleFunction(), + // seen count, and sample + ImmutableList.of(15L, ImmutableList.of(1.0, 1.0)), + // arguments + // initial sample + arrayOfBlock(DOUBLE, doubleArrayBlock(1.0, 1.0), 5), + // initial sample seen count + copyBlock(BIGINT, bigintBlock(10), 5), + // actual input values + doubleBlock(1, 1, 1, 1, 1), + // sample size + copyBlock(INTEGER, intBlock(2), 5)); + } + + /** + * Throws exception because the initial sample size is not equal to the desired sample size + */ + @Test(expectedExceptions = IllegalArgumentException.class) + public void testInitialSampleWrongSize() + { + assertAggregation( + getDoubleFunction(), + // seen count, and sample + ImmutableList.of(15L, ImmutableList.of(1.0, 1.0)), + // arguments + // initial sample + arrayOfBlock(DOUBLE, doubleArrayBlock(1.0, 1.0, 2.0), 5), + // initial sample seen count + copyBlock(BIGINT, bigintBlock(10), 5), + // actual input values + doubleBlock(1, 1, 1, 1, 1), + // sample size + copyBlock(INTEGER, intBlock(2), 5)); + } + + /** + * valid because when the initial sample was created there could have been less records than + * the desired sample size. + */ + @Test + public void testInitialSampleSmallerThanMaxSize() + { + assertAggregation( + getDoubleFunction(), + // seen count, and sample + ImmutableList.of(6L, ImmutableList.of(1.0, 1.0)), + // arguments + // initial sample + arrayOfBlock(DOUBLE, doubleArrayBlock(1.0), 5), + // initial sample seen count + copyBlock(BIGINT, bigintBlock(1), 5), + // actual input values + doubleBlock(1, 1, 1, 1, 1), + // sample size + copyBlock(INTEGER, intBlock(2), 5)); + } + + /** + * Throws exception because the processed count is less than the size of the initial sample + */ + @Test(expectedExceptions = IllegalArgumentException.class) + public void testInitialSampleSeenCountSmallerThanInitialSample() + { + assertAggregation( + getDoubleFunction(), + // seen count, and sample + ImmutableList.of(6L, ImmutableList.of(1.0, 1.0)), + // arguments + // initial sample + arrayOfBlock(DOUBLE, doubleArrayBlock(1.0, 1.0), 5), + // initial sample seen count + copyBlock(BIGINT, bigintBlock(1), 5), + // actual input values + doubleBlock(1, 1, 1, 1, 1), + // sample size + copyBlock(INTEGER, intBlock(2), 5)); + } + + @Test + public void testValidResults() + { + Object result = executeAggregation( + getDoubleFunction(), + // initial sample + copyBlock(UNKNOWN, nullBlock(), 10), + // initial sample seen count + copyBlock(BIGINT, bigintBlock(0), 10), + // actual input values + doubleBlock(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + // sample size + copyBlock(INTEGER, intBlock(4), 10)); + Set items = IntStream.range(0, 10).boxed().map(Integer::doubleValue).collect(Collectors.toSet()); + assertTrue(result instanceof List); + List resultItems = (List) result; + Long processedCount = (Long) resultItems.get(0); + assertEquals(processedCount, items.size()); + List sample = (List) resultItems.get(1); + assertTrue(items.containsAll(sample)); + } + + private JavaAggregationFunctionImplementation getFunction(Type... arguments) + { + return functionAndTypeManager.getJavaAggregateFunctionImplementation(functionAndTypeManager.lookupFunction("reservoir_sample", fromTypes(arguments))); + } + + private JavaAggregationFunctionImplementation getDoubleFunction() + { + return getFunction(new ArrayType(DOUBLE), BIGINT, DOUBLE, INTEGER); + } + + private static Block bigintBlock(long value) + { + return BIGINT.createBlockBuilder(null, 1).writeLong(value).build(); + } + + private static Block intBlock(int... values) + { + BlockBuilder blockBuilder = INTEGER.createBlockBuilder(null, values.length); + Arrays.stream(values).forEach(value -> { + INTEGER.writeLong(blockBuilder, value); + }); + return blockBuilder.build(); + } + + private static Block doubleBlock(double... values) + { + BlockBuilder blockBuilder = DOUBLE.createBlockBuilder(null, values.length); + Arrays.stream(values).forEach(value -> { + DOUBLE.writeDouble(blockBuilder, value); + }); + return blockBuilder.build(); + } + + private static Block doubleArrayBlock(double... values) + { + BlockBuilder builder = DOUBLE.createBlockBuilder(null, values.length); + Arrays.stream(values) + .forEach(value -> { + DOUBLE.writeDouble(builder, value); + }); + return builder.build(); + } + + private static Block arrayOfBlock(Type innerType, Block value, int count) + { + Type arrayType = new ArrayType(innerType); + BlockBuilder builder = arrayType.createBlockBuilder(null, count); + for (int i = 0; i < count; i++) { + builder.appendStructure(value); + } + return builder.build(); + } + + private static Block nullBlock() + { + return DOUBLE.createBlockBuilder(null, 1).appendNull().build(); + } + + private static Block copyBlock(Type type, Block value, int positionCount) + { + BlockBuilder builder = type.createBlockBuilder(null, positionCount); + for (int i = 0; i < positionCount; i++) { + type.appendTo(value, 0, builder); + } + return builder.build(); + } +}