Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion presto-docs/src/main/sphinx/functions/aggregate.rst
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ To find the `ROC curve <https://en.wikipedia.org/wiki/Receiver_operating_charact

The thresholds are defined as a sequence whose :math:`j`-th entry is the :math:`j`-th smallest threshold.


Differential Entropy Functions
-------------------------------

Expand All @@ -572,7 +573,7 @@ That is, for a random variable :math:`x`, they approximate

.. math ::

H(x) = - \int x \log_2\left(f(x)\right) dx,
h(x) = - \int x \log_2\left(f(x)\right) dx,

where :math:`f(x)` is the partial density function of :math:`x`.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package com.facebook.presto.operator.aggregation.differentialentropy;

import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.AggregationState;
Expand All @@ -23,24 +22,15 @@
import com.facebook.presto.spi.function.OutputFunction;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.type.StandardTypes;
import com.google.common.annotations.VisibleForTesting;
import io.airlift.slice.Slice;

import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.type.DoubleType.DOUBLE;
import static com.google.common.base.Verify.verify;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;

@AggregationFunction("differential_entropy")
@Description("Computes differential entropy based on random-variable samples")
public final class DifferentialEntropyAggregation
{
@VisibleForTesting
public static final String FIXED_HISTOGRAM_MLE_METHOD_NAME = "fixed_histogram_mle";
@VisibleForTesting
public static final String FIXED_HISTOGRAM_JACKNIFE_METHOD_NAME = "fixed_histogram_jacknife";

private DifferentialEntropyAggregation() {}

@InputFunction
Expand All @@ -53,38 +43,15 @@ public static void input(
@SqlType(StandardTypes.DOUBLE) double min,
@SqlType(StandardTypes.DOUBLE) double max)
{
String requestedMethod = method.toStringUtf8().toLowerCase(ENGLISH);
DifferentialEntropyStateStrategy strategy = state.getStrategy();
if (strategy == null) {
switch (requestedMethod) {
case FIXED_HISTOGRAM_MLE_METHOD_NAME:
strategy = new FixedHistogramMleStateStrategy(size, min, max);
break;
case FIXED_HISTOGRAM_JACKNIFE_METHOD_NAME:
strategy = new FixedHistogramJacknifeStateStrategy(size, min, max);
break;
default:
throw new PrestoException(
INVALID_FUNCTION_ARGUMENT,
format("In differential_entropy UDF, invalid method: %s", requestedMethod));
}
state.setStrategy(strategy);
}
else {
switch (requestedMethod.toLowerCase(ENGLISH)) {
case FIXED_HISTOGRAM_MLE_METHOD_NAME:
verify(strategy instanceof FixedHistogramMleStateStrategy,
format("Strategy class is not compatible with entropy method: %s %s", strategy.getClass().getSimpleName(), method));
break;
case FIXED_HISTOGRAM_JACKNIFE_METHOD_NAME:
verify(strategy instanceof FixedHistogramJacknifeStateStrategy,
format("Strategy class is not compatible with entropy method: %s %s", strategy.getClass().getSimpleName(), method));
break;
default:
verify(false, format("Unknown entropy method %s", method));
}
}
strategy.validateParameters(size, sample, weight, min, max);
DifferentialEntropyStateStrategy strategy = DifferentialEntropyStateStrategy.getStrategy(
state.getStrategy(),
size,
sample,
weight,
method.toStringUtf8().toLowerCase(ENGLISH),
min,
max);
state.setStrategy(strategy);
strategy.add(sample, weight);
}

Expand All @@ -95,16 +62,12 @@ public static void input(
@SqlType(StandardTypes.DOUBLE) double sample,
@SqlType(StandardTypes.DOUBLE) double weight)
{
DifferentialEntropyStateStrategy strategy = state.getStrategy();
if (state.getStrategy() == null) {
strategy = new WeightedReservoirSampleStateStrategy(size);
state.setStrategy(strategy);
}
else {
verify(strategy instanceof WeightedReservoirSampleStateStrategy,
format("Expected WeightedReservoirSampleStateStrategy, got: %s", strategy.getClass().getSimpleName()));
}
strategy.validateParameters(size, sample, weight);
DifferentialEntropyStateStrategy strategy = DifferentialEntropyStateStrategy.getStrategy(
state.getStrategy(),
size,
sample,
weight);
state.setStrategy(strategy);
strategy.add(sample, weight);
}

Expand All @@ -114,17 +77,12 @@ public static void input(
@SqlType(StandardTypes.BIGINT) long size,
@SqlType(StandardTypes.DOUBLE) double sample)
{
DifferentialEntropyStateStrategy strategy = state.getStrategy();
if (state.getStrategy() == null) {
strategy = new UnweightedReservoirSampleStateStrategy(size);
state.setStrategy(strategy);
}
else {
verify(strategy instanceof UnweightedReservoirSampleStateStrategy,
format("Expected UnweightedReservoirSampleStateStrategy, got: %s", strategy.getClass().getSimpleName()));
}
strategy.validateParameters(size, sample);
strategy.add(sample, 1.0);
DifferentialEntropyStateStrategy strategy = DifferentialEntropyStateStrategy.getStrategy(
state.getStrategy(),
size,
sample);
state.setStrategy(strategy);
strategy.add(sample);
}

@CombineFunction
Expand All @@ -141,11 +99,7 @@ public static void combine(
if (otherStrategy == null) {
return;
}

verify(strategy.getClass() == otherStrategy.getClass(),
format("In combine, %s != %s", strategy.getClass().getSimpleName(), otherStrategy.getClass().getSimpleName()));

strategy.mergeWith(otherStrategy);
DifferentialEntropyStateStrategy.combine(strategy, otherStrategy);
}

@OutputFunction("double")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,7 @@ public DifferentialEntropyStateStrategy getStrategy()
@Override
public long getEstimatedSize()
{
if (strategy == null) {
return 0;
}
return strategy.getEstimatedSize();
return strategy == null ? 0 : strategy.getEstimatedSize();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.type.Type;
import io.airlift.slice.SizeOf;
import io.airlift.slice.SliceInput;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;

import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY;
import static com.google.common.base.Verify.verify;
import static java.lang.String.format;

public class DifferentialEntropyStateSerializer
implements AccumulatorStateSerializer<DifferentialEntropyState>
Expand All @@ -39,36 +36,9 @@ public Type getSerializedType()
public void serialize(DifferentialEntropyState state, BlockBuilder output)
{
DifferentialEntropyStateStrategy strategy = state.getStrategy();

if (strategy == null) {
SliceOutput sliceOut = Slices.allocate(SizeOf.SIZE_OF_INT).getOutput();
sliceOut.appendInt(0);
VARBINARY.writeSlice(output, sliceOut.getUnderlyingSlice());
return;
}

int requiredBytes = SizeOf.SIZE_OF_INT + // Method
strategy.getRequiredBytesForSerialization(); // stateStrategy;

int requiredBytes = DifferentialEntropyStateStrategy.getRequiredBytesForSerialization(strategy);
SliceOutput sliceOut = Slices.allocate(requiredBytes).getOutput();

if (strategy instanceof UnweightedReservoirSampleStateStrategy) {
sliceOut.appendInt(1);
}
else if (strategy instanceof WeightedReservoirSampleStateStrategy) {
sliceOut.appendInt(2);
}
else if (strategy instanceof FixedHistogramMleStateStrategy) {
sliceOut.appendInt(3);
}
else if (strategy instanceof FixedHistogramJacknifeStateStrategy) {
sliceOut.appendInt(4);
}
else {
verify(false, format("Strategy cannot be serialized: %s", strategy.getClass().getSimpleName()));
}

strategy.serialize(sliceOut);
DifferentialEntropyStateStrategy.serialize(strategy, sliceOut);
VARBINARY.writeSlice(output, sliceOut.getUnderlyingSlice());
}

Expand All @@ -79,25 +49,9 @@ public void deserialize(
DifferentialEntropyState state)
{
SliceInput input = VARBINARY.getSlice(block, index).getInput();
int method = input.readInt();
switch (method) {
case 0:
verify(state.getStrategy() == null, "strategy is not null for null method");
return;
case 1:
state.setStrategy(UnweightedReservoirSampleStateStrategy.deserialize(input));
return;
case 2:
state.setStrategy(WeightedReservoirSampleStateStrategy.deserialize(input));
return;
case 3:
state.setStrategy(FixedHistogramMleStateStrategy.deserialize(input));
return;
case 4:
state.setStrategy(FixedHistogramJacknifeStateStrategy.deserialize(input));
return;
default:
verify(false, format("Unknown method code when deserializing: %s", method));
DifferentialEntropyStateStrategy strategy = DifferentialEntropyStateStrategy.deserialize(input);
if (strategy != null) {
state.setStrategy(strategy);
}
}
}
Loading