Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public BuiltInAggregationFunctionImplementation specialize(BoundVariables variab
// Build state factory and serializer
Class<?> stateClass = concreteImplementation.getStateClass();
AccumulatorStateSerializer<?> stateSerializer = getAccumulatorStateSerializer(concreteImplementation, variables, functionAndTypeManager, stateClass, classLoader);
AccumulatorStateFactory<?> stateFactory = StateCompiler.generateStateFactory(stateClass, classLoader);
AccumulatorStateFactory<?> stateFactory = StateCompiler.generateStateFactory(stateClass, variables.getTypeVariables(), classLoader);

// Bind provided dependencies to aggregation method handlers
MethodHandle inputHandle = bindDependencies(concreteImplementation.getInputFunction(), concreteImplementation.getInputDependencies(), variables, functionAndTypeManager);
Expand Down Expand Up @@ -188,7 +188,7 @@ private static AccumulatorStateSerializer<?> getAccumulatorStateSerializer(Aggre
}
}
else {
stateSerializer = generateStateSerializer(stateClass, classLoader);
stateSerializer = generateStateSerializer(stateClass, variables.getTypeVariables(), classLoader);
}
return stateSerializer;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import com.facebook.presto.spi.function.AccumulatorStateFactory;
import com.facebook.presto.spi.function.AccumulatorStateMetadata;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.function.TypeParameter;
import com.facebook.presto.spi.function.aggregation.GroupedAccumulator;
import com.facebook.presto.sql.gen.SqlTypeBytecodeExpression;
import com.google.common.collect.ImmutableList;
Expand All @@ -48,10 +49,12 @@
import org.openjdk.jol.info.ClassLayout;

import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -146,9 +149,9 @@ public static <T> AccumulatorStateSerializer<T> generateStateSerializer(Class<T>
AccumulatorStateMetadata metadata = getMetadataAnnotation(clazz);
if (metadata != null && metadata.stateSerializerClass() != void.class) {
try {
return (AccumulatorStateSerializer<T>) metadata.stateSerializerClass().getConstructor().newInstance();
return getAccumulatorStateMetadataInstance(metadata.stateSerializerClass(), fieldTypes);
}
catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e);
}
}
Expand Down Expand Up @@ -178,6 +181,30 @@ public static <T> AccumulatorStateSerializer<T> generateStateSerializer(Class<T>
}
}

private static <T> T getAccumulatorStateMetadataInstance(Class<?> clazz, Map<String, Type> fieldTypes)
throws InvocationTargetException, InstantiationException, IllegalAccessException
{
Optional<Constructor<?>> constructor = Arrays.stream(clazz.getConstructors())
.filter(cons -> Modifier.isPublic(cons.getModifiers()))
.filter(cons -> Arrays.stream(cons.getParameters()).allMatch(param ->
Type.class.equals(param.getType()) &&
// parameter must have valid @TypeParameter annotation
param.isAnnotationPresent(TypeParameter.class) &&
fieldTypes.containsKey(param.getAnnotation(TypeParameter.class).value())))
// this will only run with n > 1 values left
.reduce((first, second) -> {
throw new IllegalArgumentException("Multiple ambiguous annotated constructors in " + clazz + ". Only one valid constructor is allowed.");
});
if (!constructor.isPresent()) {
throw new IllegalArgumentException("Unable to find a suitable constructor for accumulator metadata class " + clazz);
}
Constructor<?> cons = constructor.get();
Object[] params = Arrays.stream(cons.getParameters())
.map(param -> fieldTypes.get(param.getAnnotation(TypeParameter.class).value()))
.toArray();
return (T) cons.newInstance(params);
}

private static void generateGetSerializedType(ClassDefinition definition, List<StateField> fields, CallSiteBinder callSiteBinder)
{
BytecodeBlock body = definition.declareMethod(a(PUBLIC), "getSerializedType", type(Type.class)).getBody();
Expand Down Expand Up @@ -351,9 +378,9 @@ public static <T> AccumulatorStateFactory<T> generateStateFactory(Class<T> clazz
AccumulatorStateMetadata metadata = getMetadataAnnotation(clazz);
if (metadata != null && metadata.stateFactoryClass() != void.class) {
try {
return (AccumulatorStateFactory<T>) metadata.stateFactoryClass().getConstructor().newInstance();
return getAccumulatorStateMetadataInstance(metadata.stateFactoryClass(), fieldTypes);
}
catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
import com.facebook.presto.operator.aggregation.state.VarianceState;
import com.facebook.presto.spi.function.AccumulatorState;
import com.facebook.presto.spi.function.AccumulatorStateFactory;
import com.facebook.presto.spi.function.AccumulatorStateMetadata;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.function.GroupedAccumulatorState;
import com.facebook.presto.spi.function.TypeParameter;
import com.facebook.presto.util.Reflection;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand All @@ -59,6 +61,8 @@
import static io.airlift.slice.Slices.utf8Slice;
import static io.airlift.slice.Slices.wrappedDoubleArray;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertThrows;
import static org.testng.Assert.assertTrue;

public class TestStateCompiler
{
Expand Down Expand Up @@ -361,6 +365,30 @@ public void testComplexStateEstimatedSize()
}
}

@Test
public void testStateSerializerConstructorsWithMetadata()
{
Map<String, Type> fields = ImmutableMap.of("T", BIGINT, "E", VARCHAR);
Object stateSerializer = StateCompiler.generateStateSerializer(TestAccumulatorSerializerNoType.class, fields, new DynamicClassLoader(TestAccumulatorSerializerNoType.class.getClassLoader()));
assertTrue(stateSerializer instanceof TestAccumulatorSerializerNoType);

stateSerializer = StateCompiler.generateStateSerializer(TestAccumulatorSerializerMultipleType.class, fields, new DynamicClassLoader(TestAccumulatorSerializerMultipleType.class.getClassLoader()));
assertTrue(stateSerializer instanceof TestAccumulatorSerializerMultipleType);

stateSerializer = StateCompiler.generateStateSerializer(TestAccumulatorSerializerSingleType.class, fields, new DynamicClassLoader(TestAccumulatorSerializerSingleType.class.getClassLoader()));
assertTrue(stateSerializer instanceof TestAccumulatorSerializerSingleType);

assertThrows(() -> StateCompiler.generateStateSerializer(
TestAccumulatorSerializerUntyped.class,
fields,
new DynamicClassLoader(TestAccumulatorSerializerUntyped.class.getClassLoader())));

assertThrows(() -> StateCompiler.generateStateSerializer(
TestAccumulatorAmbiguousConstructor.class,
fields,
new DynamicClassLoader(TestAccumulatorAmbiguousConstructor.class.getClassLoader())));
}

public interface TestComplexState
extends AccumulatorState
{
Expand Down Expand Up @@ -428,4 +456,77 @@ public interface SliceState

void setSlice(Slice slice);
}

private abstract static class TestAccumulatorSerializer
implements AccumulatorStateSerializer<Object>
{
@Override
public Type getSerializedType()
{
return null;
}

@Override
public void serialize(Object state, BlockBuilder out)
{}

@Override
public void deserialize(Block block, int index, Object state)
{}
}

@AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerSingleType.class)
public static class TestAccumulatorSerializerSingleType
extends TestAccumulatorSerializer
{
public TestAccumulatorSerializerSingleType(@TypeParameter("E") Type first)
{}
}

@AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerMultipleType.class)
public static class TestAccumulatorSerializerMultipleType
extends TestAccumulatorSerializer
{
public TestAccumulatorSerializerMultipleType(@TypeParameter("E") Type first, @TypeParameter("T") Type second)
{}
}

@AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerNoType.class)
public static class TestAccumulatorSerializerNoType
extends TestAccumulatorSerializer
{
public TestAccumulatorSerializerNoType()
{}
}

@AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorAmbiguousConstructor.class)
public static class TestAccumulatorAmbiguousConstructor
extends TestAccumulatorSerializer
{
public TestAccumulatorAmbiguousConstructor()
{}

public TestAccumulatorAmbiguousConstructor(@TypeParameter("E") Type type)
{}
}

// test all invalid constructor types
@AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerUntyped.class)
public static class TestAccumulatorSerializerUntyped
extends TestAccumulatorSerializer
{
public TestAccumulatorSerializerUntyped(Type x)
{}

public TestAccumulatorSerializerUntyped(int y)
{}

// type parameter G should not be in passed fields
public TestAccumulatorSerializerUntyped(@TypeParameter("G") Object y)
{}

// type parameter G should not be in passed fields
public TestAccumulatorSerializerUntyped(@TypeParameter("E") Type x, @TypeParameter("G") Long y)
{}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ private static void verifyExactOutputFunction(MethodHandle method, List<Accumula
return;
}
Class<?>[] parameterTypes = method.type().parameterArray();
checkArgument(parameterTypes.length == stateDescriptors.size() + 1, "Number of arguments for combine function must be exactly one plus than number of states.");
checkArgument(parameterTypes.length == stateDescriptors.size() + 1, "Number of arguments for output function must be exactly one plus than number of states.");
for (int i = 0; i < stateDescriptors.size(); i++) {
checkArgument(parameterTypes[i].equals(stateDescriptors.get(i).getStateInterface()), format("Type for Parameter index %d is unexpected", i));
}
Expand Down