diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java index 8f68b94935602..3cf13c853b6f6 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java @@ -26,6 +26,8 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.OutputFunction; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ApproximateCountDistinctAggregation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ApproximateCountDistinctAggregation.java index 2570a9887b8a8..b162c9fb14244 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ApproximateCountDistinctAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ApproximateCountDistinctAggregation.java @@ -18,6 +18,8 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AggregationFunction; import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.CombineFunction; import com.facebook.presto.spi.function.InputFunction; import com.facebook.presto.spi.function.OperatorDependency; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DefaultApproximateCountDistinctAggregation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DefaultApproximateCountDistinctAggregation.java index bc85ae27eaa2b..afdf7d0153e3a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DefaultApproximateCountDistinctAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DefaultApproximateCountDistinctAggregation.java @@ -18,6 +18,8 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AggregationFunction; import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.CombineFunction; import com.facebook.presto.spi.function.InputFunction; import com.facebook.presto.spi.function.OperatorDependency; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/annotations/TypeImplementationDependency.java b/presto-main/src/main/java/com/facebook/presto/operator/annotations/TypeImplementationDependency.java index 428d7de0ec4f8..2bd7efc5de760 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/annotations/TypeImplementationDependency.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/annotations/TypeImplementationDependency.java @@ -19,6 +19,8 @@ import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; +import java.util.Objects; + import static com.facebook.presto.metadata.SignatureBinder.applyBoundVariables; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static java.util.Objects.requireNonNull; @@ -38,4 +40,24 @@ public Type resolve(BoundVariables boundVariables, TypeManager typeManager, Func { return typeManager.getType(applyBoundVariables(signature, boundVariables)); } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + TypeImplementationDependency that = (TypeImplementationDependency) o; + return Objects.equals(signature, that.signature); + } + + @Override + public int hashCode() + { + return Objects.hash(signature); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapDistinctFromOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapDistinctFromOperator.java index db03040f532d9..1174446d33afa 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapDistinctFromOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapDistinctFromOperator.java @@ -17,7 +17,6 @@ import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.OperatorDependency; import com.facebook.presto.spi.function.ScalarOperator; -import com.facebook.presto.spi.function.SqlNullable; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.type.StandardTypes; @@ -47,9 +46,9 @@ public static boolean isDistinctFrom( MethodHandle valueDistinctFromFunction, @TypeParameter("K") Type keyType, @TypeParameter("V") Type valueType, - @SqlNullable @SqlType("map(K,V)") Block leftMapBlock, + @SqlType("map(K,V)") Block leftMapBlock, @IsNull boolean leftMapNull, - @SqlNullable @SqlType("map(K,V)") Block rightMapBlock, + @SqlType("map(K,V)") Block rightMapBlock, @IsNull boolean rightMapNull) { if (leftMapNull != rightMapNull) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java index c2fd5cb86d96b..efe8da74de685 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java @@ -18,8 +18,7 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.operator.ParametricImplementationsGroup; -import com.facebook.presto.operator.scalar.annotations.ScalarImplementation; -import com.facebook.presto.operator.scalar.annotations.ScalarImplementation.MethodHandleAndConstructor; +import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.type.TypeManager; import com.google.common.annotations.VisibleForTesting; @@ -38,12 +37,12 @@ public class ParametricScalar extends SqlScalarFunction { private final ScalarHeader details; - private final ParametricImplementationsGroup implementations; + private final ParametricImplementationsGroup implementations; public ParametricScalar( Signature signature, ScalarHeader details, - ParametricImplementationsGroup implementations) + ParametricImplementationsGroup implementations) { super(signature); this.details = requireNonNull(details); @@ -69,7 +68,7 @@ public String getDescription() } @VisibleForTesting - public ParametricImplementationsGroup getImplementations() + public ParametricImplementationsGroup getImplementations() { return implementations; } @@ -79,44 +78,28 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in { Signature boundSignature = applyBoundVariables(getSignature(), boundVariables, arity); if (implementations.getExactImplementations().containsKey(boundSignature)) { - ScalarImplementation implementation = implementations.getExactImplementations().get(boundSignature); - Optional methodHandleAndConstructor = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry); - checkCondition(methodHandleAndConstructor.isPresent(), FUNCTION_IMPLEMENTATION_ERROR, String.format("Exact implementation of %s do not match expected java types.", boundSignature.getName())); - return new ScalarFunctionImplementation( - implementation.isNullable(), - implementation.getArgumentProperties(), - methodHandleAndConstructor.get().getMethodHandle(), - methodHandleAndConstructor.get().getConstructor(), - isDeterministic()); + ParametricScalarImplementation implementation = implementations.getExactImplementations().get(boundSignature); + Optional scalarFunctionImplementation = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry, isDeterministic()); + checkCondition(scalarFunctionImplementation.isPresent(), FUNCTION_IMPLEMENTATION_ERROR, String.format("Exact implementation of %s do not match expected java types.", boundSignature.getName())); + return scalarFunctionImplementation.get(); } ScalarFunctionImplementation selectedImplementation = null; - for (ScalarImplementation implementation : implementations.getSpecializedImplementations()) { - Optional methodHandle = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry); - if (methodHandle.isPresent()) { + for (ParametricScalarImplementation implementation : implementations.getSpecializedImplementations()) { + Optional scalarFunctionImplementation = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry, isDeterministic()); + if (scalarFunctionImplementation.isPresent()) { checkCondition(selectedImplementation == null, AMBIGUOUS_FUNCTION_IMPLEMENTATION, "Ambiguous implementation for %s with bindings %s", getSignature(), boundVariables.getTypeVariables()); - selectedImplementation = new ScalarFunctionImplementation( - implementation.isNullable(), - implementation.getArgumentProperties(), - methodHandle.get().getMethodHandle(), - methodHandle.get().getConstructor(), - isDeterministic()); + selectedImplementation = scalarFunctionImplementation.get(); } } if (selectedImplementation != null) { return selectedImplementation; } - - for (ScalarImplementation implementation : implementations.getGenericImplementations()) { - Optional methodHandle = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry); - if (methodHandle.isPresent()) { + for (ParametricScalarImplementation implementation : implementations.getGenericImplementations()) { + Optional scalarFunctionImplementation = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry, isDeterministic()); + if (scalarFunctionImplementation.isPresent()) { checkCondition(selectedImplementation == null, AMBIGUOUS_FUNCTION_IMPLEMENTATION, "Ambiguous implementation for %s with bindings %s", getSignature(), boundVariables.getTypeVariables()); - selectedImplementation = new ScalarFunctionImplementation( - implementation.isNullable(), - implementation.getArgumentProperties(), - methodHandle.get().getMethodHandle(), - methodHandle.get().getConstructor(), - isDeterministic()); + selectedImplementation = scalarFunctionImplementation.get(); } } if (selectedImplementation != null) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java similarity index 50% rename from presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java rename to presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java index 9795884977125..06b593298ad35 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java @@ -23,7 +23,11 @@ import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention; +import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.SqlNullable; import com.facebook.presto.spi.function.SqlType; @@ -44,9 +48,11 @@ import java.lang.reflect.Method; import java.lang.reflect.Parameter; import java.util.ArrayList; +import java.util.Arrays; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.stream.Stream; @@ -65,8 +71,8 @@ import static com.facebook.presto.operator.annotations.ImplementationDependency.validateImplementationDependencyAnnotation; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.functionTypeArgumentProperty; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty; -import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentType.FUNCTION_TYPE; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentType.VALUE_TYPE; +import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.BLOCK_AND_POSITION; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.USE_BOXED_TYPE; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.USE_NULL_FLAG; @@ -76,99 +82,80 @@ import static com.facebook.presto.util.Reflection.constructorMethodHandle; import static com.facebook.presto.util.Reflection.methodHandle; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.lang.String.format; import static java.lang.invoke.MethodHandles.permuteArguments; import static java.lang.reflect.Modifier.isStatic; import static java.util.Objects.requireNonNull; -public class ScalarImplementation +public class ParametricScalarImplementation implements ParametricImplementation { private final Signature signature; - private final boolean nullable; - private final List argumentProperties; - private final MethodHandle methodHandle; - private final List dependencies; - private final Optional constructor; - private final List constructorDependencies; - private final List> argumentNativeContainerTypes; + private final List>> argumentNativeContainerTypes; // argument native container type is Optional.empty() for function type private final Map> specializedTypeParameters; + private final Class returnNativeContainerType; + private final List choices; - public ScalarImplementation( + private ParametricScalarImplementation( Signature signature, - boolean nullable, - List argumentProperties, - MethodHandle methodHandle, - List dependencies, - Optional constructor, - List constructorDependencies, - List> argumentNativeContainerTypes, - Map> specializedTypeParameters) + List>> argumentNativeContainerTypes, + Map> specializedTypeParameters, + List choices, + Class returnContainerType) { this.signature = requireNonNull(signature, "signature is null"); - this.nullable = nullable; - this.argumentProperties = ImmutableList.copyOf(requireNonNull(argumentProperties, "argumentProperties is null")); - this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); - this.dependencies = ImmutableList.copyOf(requireNonNull(dependencies, "dependencies is null")); - this.constructor = requireNonNull(constructor, "constructor is null"); - this.constructorDependencies = ImmutableList.copyOf(requireNonNull(constructorDependencies, "constructorDependencies is null")); this.argumentNativeContainerTypes = ImmutableList.copyOf(requireNonNull(argumentNativeContainerTypes, "argumentNativeContainerTypes is null")); this.specializedTypeParameters = ImmutableMap.copyOf(requireNonNull(specializedTypeParameters, "specializedTypeParameters is null")); + this.choices = requireNonNull(choices, "choices is null"); + this.returnNativeContainerType = requireNonNull(returnContainerType, "return native container type is null"); + + for (int i = 1; i < choices.size(); i++) { + checkCondition(Objects.equals(choices.get(i).getDependencies(), choices.get(0).getDependencies()), FUNCTION_IMPLEMENTATION_ERROR, "Implementations for the same function signature must have matching dependencies: %s", signature); + checkCondition(Objects.equals(choices.get(i).getConstructorDependencies(), choices.get(0).getConstructorDependencies()), FUNCTION_IMPLEMENTATION_ERROR, "Implementations for the same function signature must have matching constructor dependencies: %s", signature); + checkCondition(Objects.equals(choices.get(i).getConstructor(), choices.get(0).getConstructor()), FUNCTION_IMPLEMENTATION_ERROR, "Implementations for the same function signature must have matching constructors: %s", signature); + } } - public Optional specialize(Signature boundSignature, BoundVariables boundVariables, TypeManager typeManager, FunctionRegistry functionRegistry) + public Optional specialize(Signature boundSignature, BoundVariables boundVariables, TypeManager typeManager, FunctionRegistry functionRegistry, boolean isDeterministic) { + List implementationChoices = new ArrayList<>(); for (Map.Entry> entry : specializedTypeParameters.entrySet()) { if (!entry.getValue().isAssignableFrom(boundVariables.getTypeVariable(entry.getKey()).getJavaType())) { return Optional.empty(); } } - Class returnContainerType = getNullAwareReturnType(typeManager.getType(boundSignature.getReturnType()).getJavaType(), nullable); - if (!returnContainerType.equals(methodHandle.type().returnType())) { + + if (!returnNativeContainerType.equals(typeManager.getType(boundSignature.getReturnType()).getJavaType())) { return Optional.empty(); } + for (int i = 0; i < boundSignature.getArgumentTypes().size(); i++) { - ScalarFunctionImplementation.ArgumentProperty argumentProperty = argumentProperties.get(i); if (boundSignature.getArgumentTypes().get(i).getBase().equals(FunctionType.NAME)) { - if (argumentProperty.getArgumentType() != FUNCTION_TYPE) { + if (argumentNativeContainerTypes.get(i).isPresent()) { return Optional.empty(); } } else { - if (argumentProperty.getArgumentType() != VALUE_TYPE) { + if (!argumentNativeContainerTypes.get(i).isPresent()) { return Optional.empty(); } Class argumentType = typeManager.getType(boundSignature.getArgumentTypes().get(i)).getJavaType(); - Class argumentContainerType = getNullAwareContainerType(argumentType, argumentProperty.getNullConvention()); - if (!argumentNativeContainerTypes.get(i).isAssignableFrom(argumentContainerType)) { + if (!argumentNativeContainerTypes.get(i).get().isAssignableFrom(argumentType)) { return Optional.empty(); } } } - MethodHandle boundMethodHandle = bindDependencies(this.methodHandle, dependencies, boundVariables, typeManager, functionRegistry); - Optional boundConstructor = this.constructor.map(handle -> bindDependencies(handle, constructorDependencies, boundVariables, typeManager, functionRegistry)); - return Optional.of(new MethodHandleAndConstructor(boundMethodHandle, boundConstructor)); - } - private static Class getNullAwareReturnType(Class clazz, boolean nullable) - { - if (nullable) { - return Primitives.wrap(clazz); - } - return clazz; - } + for (ParametricScalarImplementationChoice choice : choices) { + MethodHandle boundMethodHandle = bindDependencies(choice.methodHandle, choice.getDependencies(), boundVariables, typeManager, functionRegistry); + Optional boundConstructor = choice.constructor.map(handle -> bindDependencies(handle, choice.getConstructorDependencies(), boundVariables, typeManager, functionRegistry)); - private static Class getNullAwareContainerType(Class clazz, NullConvention nullConvention) - { - if (clazz == void.class) { - return Primitives.wrap(clazz); + implementationChoices.add(new ScalarImplementationChoice(choice.nullable, choice.argumentProperties, boundMethodHandle, boundConstructor)); } - if (nullConvention == USE_BOXED_TYPE) { - return Primitives.wrap(clazz); - } - return clazz; + return Optional.of(new ScalarFunctionImplementation(implementationChoices, isDeterministic)); } @Override @@ -177,47 +164,125 @@ public boolean hasSpecializedTypeParameters() return !specializedTypeParameters.isEmpty(); } + Map> getSpecializedTypeParameters() + { + return specializedTypeParameters; + } + @Override public Signature getSignature() { return signature; } - public boolean isNullable() + List>> getArgumentNativeContainerTypes() { - return nullable; + return argumentNativeContainerTypes; } - public List getArgumentProperties() + public List getDependencies() { - return argumentProperties; + // All choices are required to have the same dependencies at this time. This is asserted in the constructor. + return choices.get(0).getDependencies(); } - public MethodHandle getMethodHandle() + @VisibleForTesting + public List getConstructorDependencies() { - return methodHandle; + // All choices are required to have the same constructor dependencies at this time. This is asserted in the constructor. + return choices.get(0).getConstructorDependencies(); } - public List getDependencies() + Class getReturnNativeContainerType() { - return dependencies; + return returnNativeContainerType; } - @VisibleForTesting - public List getConstructorDependencies() + SpecializedSignature getSpecializedSignature() { - return constructorDependencies; + return new SpecializedSignature( + signature, + argumentNativeContainerTypes, + specializedTypeParameters, + returnNativeContainerType); } - public static final class MethodHandleAndConstructor + public Builder builder() { + return new Builder(signature, argumentNativeContainerTypes, specializedTypeParameters, returnNativeContainerType); + } + + public static final class Builder + { + private final Signature signature; + private final List>> argumentNativeContainerTypes; // argument native container type is Optional.empty() for function type + private final Map> specializedTypeParameters; + private final Class returnNativeContainerType; + private final List choices; + + public Builder( + Signature signature, + List>> argumentNativeContainerTypes, + Map> specializedTypeParameters, + Class returnNativeContainerType) + { + this.signature = requireNonNull(signature, "signature is null"); + this.argumentNativeContainerTypes = ImmutableList.copyOf(requireNonNull(argumentNativeContainerTypes, "argumentNativeContainerTypes is null")); + this.specializedTypeParameters = ImmutableMap.copyOf(requireNonNull(specializedTypeParameters, "specializedTypeParameters is null")); + this.choices = new ArrayList<>(); + this.returnNativeContainerType = requireNonNull(returnNativeContainerType, "return native container type is null"); + } + + void addChoices(ParametricScalarImplementation implementation) + { + this.choices.addAll(implementation.choices); + } + + public ParametricScalarImplementation build() + { + choices.sort(ParametricScalarImplementationChoice::compareTo); + return new ParametricScalarImplementation(signature, argumentNativeContainerTypes, specializedTypeParameters, choices, returnNativeContainerType); + } + } + + public static final class ParametricScalarImplementationChoice + implements Comparable + { + private final boolean nullable; + private final List argumentProperties; private final MethodHandle methodHandle; private final Optional constructor; - - public MethodHandleAndConstructor(MethodHandle methodHandle, Optional constructor) + private final List dependencies; + private final List constructorDependencies; + private final int numberOfBlockPositionArguments; + + private ParametricScalarImplementationChoice( + boolean nullable, + List argumentProperties, + MethodHandle methodHandle, + Optional constructor, + List dependencies, + List constructorDependencies) { - this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); + this.nullable = nullable; + this.argumentProperties = argumentProperties; + this.methodHandle = methodHandle; this.constructor = requireNonNull(constructor, "constructor is null"); + this.dependencies = ImmutableList.copyOf(requireNonNull(dependencies, "dependencies is null")); + this.constructorDependencies = ImmutableList.copyOf(requireNonNull(constructorDependencies, "constructorDependencies is null")); + + int numberOfBlockPositionArguments = 0; + for (ArgumentProperty argumentProperty : argumentProperties) { + if (argumentProperty.getArgumentType() == VALUE_TYPE && argumentProperty.getNullConvention().equals(BLOCK_AND_POSITION)) { + numberOfBlockPositionArguments++; + } + } + this.numberOfBlockPositionArguments = numberOfBlockPositionArguments; + } + + public boolean isNullable() + { + return nullable; } public MethodHandle getMethodHandle() @@ -225,10 +290,72 @@ public MethodHandle getMethodHandle() return methodHandle; } + public List getDependencies() + { + return dependencies; + } + + @VisibleForTesting + List getConstructorDependencies() + { + return constructorDependencies; + } + public Optional getConstructor() { return constructor; } + + @Override + public int compareTo(ParametricScalarImplementationChoice choice) + { + if (choice.numberOfBlockPositionArguments < this.numberOfBlockPositionArguments) { + return 1; + } + return -1; + } + } + + public static final class SpecializedSignature + { + private final Signature signature; + private final List>> argumentNativeContainerTypes; + private final Map> specializedTypeParameters; + private final Class returnNativeContainerType; + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SpecializedSignature that = (SpecializedSignature) o; + return Objects.equals(signature, that.signature) && + Objects.equals(argumentNativeContainerTypes, that.argumentNativeContainerTypes) && + Objects.equals(specializedTypeParameters, that.specializedTypeParameters) && + Objects.equals(returnNativeContainerType, that.returnNativeContainerType); + } + + @Override + public int hashCode() + { + return Objects.hash(signature, argumentNativeContainerTypes, specializedTypeParameters, returnNativeContainerType); + } + + private SpecializedSignature( + Signature signature, + List>> argumentNativeContainerTypes, + Map> specializedTypeParameters, + Class returnNativeContainerType) + { + this.signature = signature; + this.argumentNativeContainerTypes = argumentNativeContainerTypes; + this.specializedTypeParameters = specializedTypeParameters; + this.returnNativeContainerType = returnNativeContainerType; + } } public static final class Parser @@ -238,7 +365,7 @@ public static final class Parser private final List argumentProperties = new ArrayList<>(); private final TypeSignature returnType; private final List argumentTypes = new ArrayList<>(); - private final List> argumentNativeContainerTypes = new ArrayList<>(); + private final List>> argumentNativeContainerTypes = new ArrayList<>(); private final MethodHandle methodHandle; private final List dependencies = new ArrayList<>(); private final Set typeParameters = new LinkedHashSet<>(); @@ -248,6 +375,9 @@ public static final class Parser private final Optional constructorMethodHandle; private final List constructorDependencies = new ArrayList<>(); private final List longVariableConstraints; + private final Class returnNativeContainerType; + + private final List choices = new ArrayList<>(); private Parser(String functionName, Method method, Optional> constructor) { @@ -255,8 +385,7 @@ private Parser(String functionName, Method method, Optional> cons this.nullable = method.getAnnotation(SqlNullable.class) != null; checkArgument(nullable || !containsLegacyNullable(method.getAnnotations()), "Method [%s] is annotated with @Nullable but not @SqlNullable", method); - Stream.of(method.getAnnotationsByType(TypeParameter.class)) - .forEach(typeParameters::add); + typeParameters.addAll(Arrays.asList(method.getAnnotationsByType(TypeParameter.class))); literalParameters = parseLiteralParameters(method); typeParameterNames = typeParameters.stream() @@ -268,6 +397,8 @@ private Parser(String functionName, Method method, Optional> cons this.returnType = parseTypeSignature(returnType.value(), literalParameters); Class actualReturnType = method.getReturnType(); + this.returnNativeContainerType = Primitives.unwrap(actualReturnType); + if (Primitives.isWrapperType(actualReturnType)) { checkArgument(nullable, "Method [%s] has wrapper return type %s but is missing @SqlNullable", method, actualReturnType.getSimpleName()); } @@ -285,21 +416,27 @@ else if (actualReturnType.isPrimitive()) { "Expected type parameter to only contain A-Z and 0-9 (starting with A-Z), but got %s on method [%s]", typeParameter.value(), method); } + inferSpecialization(method, actualReturnType, returnType.value(), nullable); parseArguments(method); this.constructorMethodHandle = getConstructor(method, constructor); this.methodHandle = getMethodHandle(method); + + ParametricScalarImplementationChoice choice = new ParametricScalarImplementationChoice(nullable, argumentProperties, methodHandle, constructorMethodHandle, dependencies, constructorDependencies); + choices.add(choice); } private void parseArguments(Method method) { - for (int i = 0; i < method.getParameterCount(); i++) { + int i = 0; + while (i < method.getParameterCount()) { Parameter parameter = method.getParameters()[i]; Class parameterType = parameter.getType(); // Skip injected parameters if (parameterType == ConnectorSession.class) { + i++; continue; } @@ -308,10 +445,11 @@ private void parseArguments(Method method) // check if only declared typeParameters and literalParameters are used validateImplementationDependencyAnnotation(method, implementationDependency.get(), typeParameterNames, literalParameters); dependencies.add(createDependency(implementationDependency.get(), literalParameters)); + i++; } else { Annotation[] annotations = parameter.getAnnotations(); - checkArgument(!Stream.of(annotations).anyMatch(IsNull.class::isInstance), "Method [%s] has @IsNull parameter that does not follow a @SqlType parameter", method); + checkArgument(Stream.of(annotations).noneMatch(IsNull.class::isInstance), "Method [%s] has @IsNull parameter that does not follow a @SqlType parameter", method); SqlType type = Stream.of(annotations) .filter(SqlType.class::isInstance) @@ -319,64 +457,97 @@ private void parseArguments(Method method) .findFirst() .orElseThrow(() -> new IllegalArgumentException(format("Method [%s] is missing @SqlType annotation for parameter", method))); TypeSignature typeSignature = parseTypeSignature(type.value(), literalParameters); - boolean nullableArgument = Stream.of(annotations).anyMatch(SqlNullable.class::isInstance); - checkArgument(nullableArgument || !containsLegacyNullable(annotations), "Method [%s] has parameter annotated with @Nullable but not @SqlNullable", method); - - boolean hasNullFlag = false; - if (method.getParameterCount() > (i + 1)) { - Annotation[] parameterAnnotations = method.getParameterAnnotations()[i + 1]; - if (Stream.of(parameterAnnotations).anyMatch(IsNull.class::isInstance)) { - Class isNullType = method.getParameterTypes()[i + 1]; - - checkArgument(Stream.of(parameterAnnotations).filter(FunctionsParserHelper::isPrestoAnnotation).allMatch(IsNull.class::isInstance), "Method [%s] has @IsNull parameter that has other annotations", method); - checkArgument(isNullType == boolean.class, "Method [%s] has non-boolean parameter with @IsNull", method); - checkArgument((parameterType == Void.class) || !Primitives.isWrapperType(parameterType), "Method [%s] uses @IsNull following a parameter with boxed primitive type: %s", method, parameterType.getSimpleName()); - - nullableArgument = true; - hasNullFlag = true; - } - } - - if (Primitives.isWrapperType(parameterType)) { - checkArgument(nullableArgument, "Method [%s] has parameter with wrapper type %s that is missing @SqlNullable", method, parameterType.getSimpleName()); - } - else if (parameterType.isPrimitive() && !hasNullFlag) { - checkArgument(!nullableArgument, "Method [%s] has parameter with primitive type %s annotated with @SqlNullable", method, parameterType.getSimpleName()); - } - - if (typeParameterNames.contains(type.value()) && !(parameterType == Object.class && nullableArgument)) { - // Infer specialization on this type parameter. We don't do this for @SqlNullable Object because it could match a type like BIGINT - Class specialization = specializedTypeParameters.get(type.value()); - Class nativeParameterType = Primitives.unwrap(parameterType); - checkArgument(specialization == null || specialization.equals(nativeParameterType), "Method [%s] type %s has conflicting specializations %s and %s", method, type.value(), specialization, nativeParameterType); - specializedTypeParameters.put(type.value(), nativeParameterType); - } - argumentNativeContainerTypes.add(parameterType); argumentTypes.add(typeSignature); - if (hasNullFlag) { - // skip @IsNull parameter - i++; - } - if (typeSignature.getBase().equals(FunctionType.NAME)) { + // function type checkCondition(parameterType.isAnnotationPresent(FunctionalInterface.class), FUNCTION_IMPLEMENTATION_ERROR, "argument %s is marked as lambda but the function interface class is not annotated: %s", i, methodHandle); argumentProperties.add(functionTypeArgumentProperty(parameterType)); + argumentNativeContainerTypes.add(Optional.empty()); + i++; } else { + // value type NullConvention nullConvention; - if (!nullableArgument) { - nullConvention = RETURN_NULL_ON_NULL; + if (Stream.of(annotations).anyMatch(SqlNullable.class::isInstance)) { + checkCondition(!parameterType.isPrimitive(), FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] has parameter with primitive type %s annotated with @SqlNullable", method, parameterType.getSimpleName()); + + nullConvention = NullConvention.USE_BOXED_TYPE; + } + else if (Stream.of(annotations).anyMatch(BlockPosition.class::isInstance)) { + checkState(method.getParameterCount() > (i + 1)); + checkState(parameterType == Block.class); + + nullConvention = NullConvention.BLOCK_AND_POSITION; + Annotation[] parameterAnnotations = method.getParameterAnnotations()[i + 1]; + checkState(Stream.of(parameterAnnotations).anyMatch(BlockIndex.class::isInstance)); + } + else { + // USE_NULL_FLAG or RETURN_NULL_ON_NULL + checkCondition(parameterType == Void.class || !Primitives.isWrapperType(parameterType), FUNCTION_IMPLEMENTATION_ERROR, "A parameter with USE_NULL_FLAG or RETURN_NULL_ON_NULL convention must not use wrapper type. Found in method [%s]", method); + + boolean useNullFlag = false; + if (method.getParameterCount() > (i + 1)) { + Annotation[] parameterAnnotations = method.getParameterAnnotations()[i + 1]; + if (Stream.of(parameterAnnotations).anyMatch(IsNull.class::isInstance)) { + Class isNullType = method.getParameterTypes()[i + 1]; + + checkArgument(Stream.of(parameterAnnotations).filter(FunctionsParserHelper::isPrestoAnnotation).allMatch(IsNull.class::isInstance), "Method [%s] has @IsNull parameter that has other annotations", method); + checkArgument(isNullType == boolean.class, "Method [%s] has non-boolean parameter with @IsNull", method); + checkArgument((parameterType == Void.class) || !Primitives.isWrapperType(parameterType), "Method [%s] uses @IsNull following a parameter with boxed primitive type: %s", method, parameterType.getSimpleName()); + + useNullFlag = true; + } + } + + if (useNullFlag) { + nullConvention = USE_NULL_FLAG; + } + else { + nullConvention = RETURN_NULL_ON_NULL; + } + } + + if (nullConvention == BLOCK_AND_POSITION) { + argumentNativeContainerTypes.add(Optional.of(type.nativeContainerType())); } else { - nullConvention = hasNullFlag ? USE_NULL_FLAG : USE_BOXED_TYPE; + inferSpecialization(method, parameterType, type.value(), nullConvention); + + checkCondition(type.nativeContainerType().equals(Object.class), FUNCTION_IMPLEMENTATION_ERROR, "@SqlType can only contain an explicitly specified nativeContainerType when using @BlockPosition"); + argumentNativeContainerTypes.add(Optional.of(Primitives.unwrap(parameterType))); } + argumentProperties.add(valueTypeArgumentProperty(nullConvention)); + i += nullConvention.getParameterCount(); } } } } + private void inferSpecialization(Method method, Class parameterType, String typeParameterName, NullConvention nullConventionFlag) + { + checkArgument(nullConventionFlag != BLOCK_AND_POSITION); + + if (nullConventionFlag == USE_BOXED_TYPE) { + inferSpecialization(method, parameterType, typeParameterName, true); + } + else { + inferSpecialization(method, parameterType, typeParameterName, false); + } + } + + private void inferSpecialization(Method method, Class parameterType, String typeParameterName, boolean nullable) + { + if (typeParameterNames.contains(typeParameterName) && !(parameterType == Object.class && nullable)) { + // Infer specialization on this type parameter. We don't do this for @SqlNullable Object because it could match a type like BIGINT + Class specialization = specializedTypeParameters.get(typeParameterName); + Class nativeParameterType = Primitives.unwrap(parameterType); + checkArgument(specialization == null || specialization.equals(nativeParameterType), "Method [%s] type %s has conflicting specializations %s and %s", method, typeParameterName, specialization, nativeParameterType); + specializedTypeParameters.put(typeParameterName, nativeParameterType); + } + } + // Find matching constructor, if this is an instance method, and populate constructorDependencies private Optional getConstructor(Method method, Optional> optionalConstructor) { @@ -427,7 +598,7 @@ private MethodHandle getMethodHandle(Method method) return methodHandle; } - public ScalarImplementation get() + public ParametricScalarImplementation get() { Signature signature = new Signature( functionName, @@ -437,19 +608,16 @@ public ScalarImplementation get() returnType, argumentTypes, false); - return new ScalarImplementation( + + return new ParametricScalarImplementation( signature, - nullable, - argumentProperties, - methodHandle, - dependencies, - constructorMethodHandle, - constructorDependencies, argumentNativeContainerTypes, - specializedTypeParameters); + specializedTypeParameters, + choices, + returnNativeContainerType); } - public static ScalarImplementation parseImplementation(String functionName, Method method, Optional> constructor) + static ParametricScalarImplementation parseImplementation(String functionName, Method method, Optional> constructor) { return new Parser(functionName, method, constructor).get(); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java index 92d9dbb9e06a5..751ad049f39b4 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java @@ -18,6 +18,7 @@ import com.facebook.presto.operator.ParametricImplementationsGroup; import com.facebook.presto.operator.annotations.FunctionsParserHelper; import com.facebook.presto.operator.scalar.ParametricScalar; +import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation.SpecializedSignature; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.SqlType; @@ -26,11 +27,15 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Method; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import static com.facebook.presto.operator.scalar.annotations.OperatorValidator.validateOperator; +import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; +import static com.facebook.presto.util.Failures.checkCondition; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; @@ -65,7 +70,7 @@ private static List findScalarsInFunctionDefinitionClass for (ScalarImplementationHeader header : classHeaders) { Set methods = FunctionsParserHelper.findPublicMethodsWithAnnotation(annotated, SqlType.class, ScalarFunction.class, ScalarOperator.class); - checkArgument(!methods.isEmpty(), "Parametric class [%s] does not have any annotated methods", annotated.getName()); + checkCondition(!methods.isEmpty(), FUNCTION_IMPLEMENTATION_ERROR, "Parametric class [%s] does not have any annotated methods", annotated.getName()); for (Method method : methods) { checkArgument(method.getAnnotation(ScalarFunction.class) == null, "Parametric class method [%s] is annotated with @ScalarFunction", method); checkArgument(method.getAnnotation(ScalarOperator.class) == null, "Parametric class method [%s] is annotated with @ScalarOperator", method); @@ -80,8 +85,8 @@ private static List findScalarsInFunctionSetClass(Class< { ImmutableList.Builder builder = ImmutableList.builder(); for (Method method : FunctionsParserHelper.findPublicMethodsWithAnnotation(annotated, SqlType.class, ScalarFunction.class, ScalarOperator.class)) { - checkArgument((method.getAnnotation(ScalarFunction.class) != null) || (method.getAnnotation(ScalarOperator.class) != null), - "Method [%s] annotated with @SqlType is missing @ScalarFunction or @ScalarOperator", method); + checkCondition((method.getAnnotation(ScalarFunction.class) != null) || (method.getAnnotation(ScalarOperator.class) != null), + FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] annotated with @SqlType is missing @ScalarFunction or @ScalarOperator", method); for (ScalarImplementationHeader header : ScalarImplementationHeader.fromAnnotatedElement(method)) { builder.add(new ScalarHeaderAndMethods(header, ImmutableSet.of(method))); } @@ -93,16 +98,32 @@ private static List findScalarsInFunctionSetClass(Class< private static SqlScalarFunction parseParametricScalar(ScalarHeaderAndMethods scalar, Optional> constructor) { - ParametricImplementationsGroup.Builder implementationsBuilder = ParametricImplementationsGroup.builder(); ScalarImplementationHeader header = scalar.getHeader(); checkArgument(!header.getName().isEmpty()); + Map signatures = new HashMap<>(); for (Method method : scalar.getMethods()) { - ScalarImplementation implementation = ScalarImplementation.Parser.parseImplementation(header.getName(), method, constructor); - implementationsBuilder.addImplementation(implementation); + ParametricScalarImplementation implementation = ParametricScalarImplementation.Parser.parseImplementation(header.getName(), method, constructor); + if (!signatures.containsKey(implementation.getSpecializedSignature())) { + ParametricScalarImplementation.Builder builder = new ParametricScalarImplementation.Builder( + implementation.getSignature(), + implementation.getArgumentNativeContainerTypes(), + implementation.getSpecializedTypeParameters(), + implementation.getReturnNativeContainerType()); + signatures.put(implementation.getSpecializedSignature(), builder); + builder.addChoices(implementation); + } + else { + ParametricScalarImplementation.Builder builder = signatures.get(implementation.getSpecializedSignature()); + builder.addChoices(implementation); + } } - ParametricImplementationsGroup implementations = implementationsBuilder.build(); + ParametricImplementationsGroup.Builder implementationsBuilder = ParametricImplementationsGroup.builder(); + for (ParametricScalarImplementation.Builder implementation : signatures.values()) { + implementationsBuilder.addImplementation(implementation.build()); + } + ParametricImplementationsGroup implementations = implementationsBuilder.build(); Signature scalarSignature = implementations.getSignature(); header.getOperatorType().ifPresent(operatorType -> diff --git a/presto-main/src/main/java/com/facebook/presto/type/UnknownOperators.java b/presto-main/src/main/java/com/facebook/presto/type/UnknownOperators.java index 56b6a59b7c666..94744e3abf901 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/UnknownOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/UnknownOperators.java @@ -103,9 +103,9 @@ public static Long hashCode(@SqlType("unknown") @SqlNullable Void value) @ScalarOperator(IS_DISTINCT_FROM) @SqlType(StandardTypes.BOOLEAN) public static boolean isDistinctFrom( - @SqlType("unknown") @SqlNullable Void left, + @SqlType("unknown") Void left, @IsNull boolean leftNull, - @SqlType("unknown") @SqlNullable Void right, + @SqlType("unknown") Void right, @IsNull boolean rightNull) { return false; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java b/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java index 030ab612e6614..225e306d000c2 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java @@ -21,8 +21,6 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.AggregationImplementation; import com.facebook.presto.operator.aggregation.AggregationMetadata; -import com.facebook.presto.operator.aggregation.BlockIndex; -import com.facebook.presto.operator.aggregation.BlockPosition; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.operator.aggregation.LazyAccumulatorFactoryBinder; import com.facebook.presto.operator.aggregation.ParametricAggregation; @@ -39,6 +37,8 @@ import com.facebook.presto.spi.function.AggregationFunction; import com.facebook.presto.spi.function.AggregationState; import com.facebook.presto.spi.function.AggregationStateSerializerFactory; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.CombineFunction; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.InputFunction; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForScalars.java b/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForScalars.java index fb87ee0afb139..e5fc8ee5457a5 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForScalars.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForScalars.java @@ -148,7 +148,7 @@ public static class WithNullablePrimitiveArgScalarFunction @SqlType(StandardTypes.DOUBLE) public static double fun( @SqlType(StandardTypes.DOUBLE) double v, - @SqlNullable @SqlType(StandardTypes.DOUBLE) double v2, + @SqlType(StandardTypes.DOUBLE) double v2, @IsNull boolean v2isNull) { return v; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestCountNullAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestCountNullAggregation.java index 23947fdef2e7f..b710e282a1e90 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestCountNullAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestCountNullAggregation.java @@ -19,6 +19,8 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AggregationFunction; import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.CombineFunction; import com.facebook.presto.spi.function.InputFunction; import com.facebook.presto.spi.function.OutputFunction; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestScalarImplementationValidation.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestParametricScalarImplementationValidation.java similarity index 84% rename from presto-main/src/test/java/com/facebook/presto/operator/scalar/TestScalarImplementationValidation.java rename to presto-main/src/test/java/com/facebook/presto/operator/scalar/TestParametricScalarImplementationValidation.java index b705eaa3c6648..1974a62c26adc 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestScalarImplementationValidation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestParametricScalarImplementationValidation.java @@ -26,15 +26,15 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.fail; -public class TestScalarImplementationValidation +public class TestParametricScalarImplementationValidation { - private static final MethodHandle STATE_FACTORY = methodHandle(TestScalarImplementationValidation.class, "createState"); + private static final MethodHandle STATE_FACTORY = methodHandle(TestParametricScalarImplementationValidation.class, "createState"); @Test public void testConnectorSessionPosition() { // Without cached instance factory - MethodHandle validFunctionMethodHandle = methodHandle(TestScalarImplementationValidation.class, "validConnectorSessionParameterPosition", ConnectorSession.class, long.class, long.class); + MethodHandle validFunctionMethodHandle = methodHandle(TestParametricScalarImplementationValidation.class, "validConnectorSessionParameterPosition", ConnectorSession.class, long.class, long.class); ScalarFunctionImplementation validFunction = new ScalarFunctionImplementation( false, ImmutableList.of( @@ -50,7 +50,7 @@ public void testConnectorSessionPosition() ImmutableList.of( valueTypeArgumentProperty(RETURN_NULL_ON_NULL), valueTypeArgumentProperty(RETURN_NULL_ON_NULL)), - methodHandle(TestScalarImplementationValidation.class, "invalidConnectorSessionParameterPosition", long.class, long.class, ConnectorSession.class), + methodHandle(TestParametricScalarImplementationValidation.class, "invalidConnectorSessionParameterPosition", long.class, long.class, ConnectorSession.class), false); fail("expected exception"); } @@ -59,7 +59,7 @@ public void testConnectorSessionPosition() } // With cached instance factory - MethodHandle validFunctionWithInstanceFactoryMethodHandle = methodHandle(TestScalarImplementationValidation.class, "validConnectorSessionParameterPosition", Object.class, ConnectorSession.class, long.class, long.class); + MethodHandle validFunctionWithInstanceFactoryMethodHandle = methodHandle(TestParametricScalarImplementationValidation.class, "validConnectorSessionParameterPosition", Object.class, ConnectorSession.class, long.class, long.class); ScalarFunctionImplementation validFunctionWithInstanceFactory = new ScalarFunctionImplementation( false, ImmutableList.of( @@ -76,7 +76,7 @@ public void testConnectorSessionPosition() ImmutableList.of( valueTypeArgumentProperty(RETURN_NULL_ON_NULL), valueTypeArgumentProperty(RETURN_NULL_ON_NULL)), - methodHandle(TestScalarImplementationValidation.class, "invalidConnectorSessionParameterPosition", Object.class, long.class, long.class, ConnectorSession.class), + methodHandle(TestParametricScalarImplementationValidation.class, "invalidConnectorSessionParameterPosition", Object.class, long.class, long.class, ConnectorSession.class), Optional.of(STATE_FACTORY), false); fail("expected exception"); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestScalarValidation.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestScalarValidation.java index 0d6bd6baa32be..78154c4c5e13e 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestScalarValidation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestScalarValidation.java @@ -15,6 +15,7 @@ import com.facebook.presto.metadata.FunctionListBuilder; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlNullable; @@ -42,7 +43,7 @@ public static final class BogusParametricMethodAnnotation public static void bad() {} } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Parametric class .* does not have any annotated methods") + @Test(expectedExceptions = PrestoException.class, expectedExceptionsMessageRegExp = "Parametric class .* does not have any annotated methods") public void testNoParametricMethods() { extractParametricScalar(NoParametricMethods.class); @@ -64,7 +65,7 @@ public static final class MethodMissingReturnAnnotation public static void bad() {} } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* annotated with @SqlType is missing @ScalarFunction or @ScalarOperator") + @Test(expectedExceptions = PrestoException.class, expectedExceptionsMessageRegExp = "Method .* annotated with @SqlType is missing @ScalarFunction or @ScalarOperator") public void testMethodMissingScalarAnnotation() { extractScalars(MethodMissingScalarAnnotation.class); @@ -110,7 +111,7 @@ public static long bad() } } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* has parameter with wrapper type Boolean that is missing @SqlNullable") + @Test(expectedExceptions = PrestoException.class, expectedExceptionsMessageRegExp = "A parameter with USE_NULL_FLAG or RETURN_NULL_ON_NULL convention must not use wrapper type. Found in method .*") public void testPrimitiveWrapperParameterWithoutNullable() { extractScalars(PrimitiveWrapperParameterWithoutNullable.class); @@ -126,7 +127,7 @@ public static long bad(@SqlType(StandardTypes.BOOLEAN) Boolean boxed) } } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* has parameter with primitive type double annotated with @SqlNullable") + @Test(expectedExceptions = PrestoException.class, expectedExceptionsMessageRegExp = "Method .* has parameter with primitive type double annotated with @SqlNullable") public void testPrimitiveParameterWithNullable() { extractScalars(PrimitiveParameterWithNullable.class); @@ -191,22 +192,6 @@ public static Long bad() } } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* has parameter annotated with @Nullable but not @SqlNullable") - public void testParameterWithLegacyNullable() - { - extractScalars(ParameterWithLegacyNullable.class); - } - - public static final class ParameterWithLegacyNullable - { - @ScalarFunction - @SqlType(StandardTypes.BIGINT) - public static long bad(@Nullable @SqlType(StandardTypes.DOUBLE) Double value) - { - return 0; - } - } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* has @IsNull parameter that does not follow a @SqlType parameter") public void testParameterWithConnectorAndIsNull() { @@ -255,7 +240,7 @@ public static long bad(@SqlType(StandardTypes.BIGINT) long value, @IsNull int is } } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* uses @IsNull following a parameter with boxed primitive type: Long") + @Test(expectedExceptions = PrestoException.class, expectedExceptionsMessageRegExp = "A parameter with USE_NULL_FLAG or RETURN_NULL_ON_NULL convention must not use wrapper type. Found in method .*") public void testParameterWithBoxedPrimitiveIsNull() { extractScalars(ParameterWithBoxedPrimitiveIsNull.class); @@ -265,7 +250,7 @@ public static final class ParameterWithBoxedPrimitiveIsNull { @ScalarFunction @SqlType(StandardTypes.BIGINT) - public static long bad(@SqlNullable @SqlType(StandardTypes.BIGINT) Long value, @IsNull boolean isNull) + public static long bad(@SqlType(StandardTypes.BIGINT) Long value, @IsNull boolean isNull) { return 0; } diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestBlockAndPositionNullConvention.java b/presto-main/src/test/java/com/facebook/presto/type/TestBlockAndPositionNullConvention.java index 236c646e59a2e..4e3fdf0723986 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestBlockAndPositionNullConvention.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestBlockAndPositionNullConvention.java @@ -13,30 +13,27 @@ */ package com.facebook.presto.type; -import com.facebook.presto.metadata.BoundVariables; -import com.facebook.presto.metadata.FunctionKind; -import com.facebook.presto.metadata.FunctionRegistry; -import com.facebook.presto.metadata.Signature; -import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.operator.scalar.AbstractTestFunctions; -import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; -import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice; import com.facebook.presto.spi.block.Block; -import com.facebook.presto.spi.type.TypeManager; -import com.google.common.collect.ImmutableList; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlNullable; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.Type; +import io.airlift.slice.Slice; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import java.lang.invoke.MethodHandle; -import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; -import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty; -import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.BLOCK_AND_POSITION; -import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL; import static com.facebook.presto.spi.type.BigintType.BIGINT; -import static com.facebook.presto.type.TestBlockAndPositionNullConvention.FunctionWithBlockAndPositionConvention.BLOCK_AND_POSITION_CONVENTION; -import static com.facebook.presto.util.Reflection.methodHandle; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; public class TestBlockAndPositionNullConvention @@ -45,76 +42,132 @@ public class TestBlockAndPositionNullConvention @BeforeClass public void setUp() { - registerScalarFunction(BLOCK_AND_POSITION_CONVENTION); + registerParametricScalar(FunctionWithBlockAndPositionConvention.class); } @Test public void testBlockPosition() { - assertFunction("identityFunction(9876543210)", BIGINT, 9876543210L); - assertFunction("identityFunction(bound_long)", BIGINT, 1234L); - assertTrue(FunctionWithBlockAndPositionConvention.hitBlockPosition.get()); + assertFunction("test_block_position(9876543210)", BIGINT, 9876543210L); + assertFalse(FunctionWithBlockAndPositionConvention.hitBlockPositionBigint.get()); + + assertFunction("test_block_position(bound_long)", BIGINT, 1234L); + assertTrue(FunctionWithBlockAndPositionConvention.hitBlockPositionBigint.get()); + + assertFunction("test_block_position(3.0E0)", DOUBLE, 3.0); + assertFalse(FunctionWithBlockAndPositionConvention.hitBlockPositionDouble.get()); + + assertFunction("test_block_position(bound_double)", DOUBLE, 12.34); + assertTrue(FunctionWithBlockAndPositionConvention.hitBlockPositionDouble.get()); + + assertFunction("test_block_position(bound_string)", VARCHAR, "hello"); + assertTrue(FunctionWithBlockAndPositionConvention.hitBlockPositionSlice.get()); + + // TODO: add adaptations so these will pass + //assertFunction("test_block_position(null)", UNKNOWN, null); + //assertFalse(FunctionWithBlockAndPositionConvention.hitBlockPositionObject.get()); + + assertFunction("test_block_position(false)", BOOLEAN, false); + assertFalse(FunctionWithBlockAndPositionConvention.hitBlockPositionBoolean.get()); + + assertFunction("test_block_position(bound_boolean)", BOOLEAN, true); + assertTrue(FunctionWithBlockAndPositionConvention.hitBlockPositionBoolean.get()); } + @ScalarFunction("test_block_position") public static class FunctionWithBlockAndPositionConvention - extends SqlScalarFunction { - private static final AtomicBoolean hitBlockPosition = new AtomicBoolean(); - public static final FunctionWithBlockAndPositionConvention BLOCK_AND_POSITION_CONVENTION = new FunctionWithBlockAndPositionConvention(); + private static final AtomicBoolean hitBlockPositionBigint = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionDouble = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionSlice = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionBoolean = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionObject = new AtomicBoolean(); - private static final MethodHandle METHOD_HANDLE_BLOCK_AND_POSITION = methodHandle(FunctionWithBlockAndPositionConvention.class, "getBlockPosition", Block.class, int.class); - private static final MethodHandle METHOD_HANDLE_NULL_ON_NULL = methodHandle(FunctionWithBlockAndPositionConvention.class, "getLong", long.class); + /* + // generic implementations + // these will not work right now because MethodHandle is not properly adapted - protected FunctionWithBlockAndPositionConvention() + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Object generic(@TypeParameter("E") Type type, @SqlNullable @SqlType("E") Object object) { - super(new Signature("identityFunction", FunctionKind.SCALAR, BIGINT.getTypeSignature(), BIGINT.getTypeSignature())); + return object; } - @Override - public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Object generic(@TypeParameter("E") Type type, @BlockPosition @SqlType("E") Block block, @BlockIndex int position) { - return new ScalarFunctionImplementation( - ImmutableList.of( - new ScalarImplementationChoice( - false, - ImmutableList.of(valueTypeArgumentProperty(RETURN_NULL_ON_NULL)), - METHOD_HANDLE_NULL_ON_NULL, - Optional.empty()), - new ScalarImplementationChoice( - false, - ImmutableList.of(valueTypeArgumentProperty(BLOCK_AND_POSITION)), - METHOD_HANDLE_BLOCK_AND_POSITION, - Optional.empty())), - isDeterministic()); + hitBlockPositionObject.set(true); + return TypeUtils.readNativeValue(type, block, position); } + */ + + // specialized - public static long getBlockPosition(Block block, int position) + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Slice specializedSlice(@TypeParameter("E") Type type, @SqlNullable @SqlType("E") Slice slice) { - hitBlockPosition.set(true); - return BIGINT.getLong(block, position); + return slice; } - public static long getLong(long number) + @TypeParameter("E") + @SqlType("E") + public static Slice specializedSlice(@TypeParameter("E") Type type, @BlockPosition @SqlType(value = "E", nativeContainerType = Slice.class) Block block, @BlockIndex int position) + { + hitBlockPositionSlice.set(true); + return type.getSlice(block, position); + } + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Boolean speciailizedBoolean(@TypeParameter("E") Type type, @SqlType("E") boolean bool) + { + return bool; + } + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Boolean speciailizedBoolean(@TypeParameter("E") Type type, @BlockPosition @SqlType(value = "E", nativeContainerType = boolean.class) Block block, @BlockIndex int position) + { + hitBlockPositionBoolean.set(true); + return type.getBoolean(block, position); + } + + // exact + + @SqlType(StandardTypes.BIGINT) + public static long getLong(@SqlType(StandardTypes.BIGINT) long number) { return number; } - @Override - public boolean isDeterministic() + @SqlType(StandardTypes.BIGINT) + public static long getBlockPosition(@BlockPosition @SqlType(value = StandardTypes.BIGINT, nativeContainerType = long.class) Block block, @BlockIndex int position) { - return true; + hitBlockPositionBigint.set(true); + return BIGINT.getLong(block, position); } - @Override - public boolean isHidden() + @SqlType(StandardTypes.DOUBLE) + @SqlNullable + public static Double getDouble(@SqlType(StandardTypes.DOUBLE) double number) { - return false; + return number; } - @Override - public String getDescription() + @SqlType(StandardTypes.DOUBLE) + @SqlNullable + public static Double getDouble(@BlockPosition @SqlType(value = StandardTypes.DOUBLE, nativeContainerType = double.class) Block block, @BlockIndex int position) { - return ""; + hitBlockPositionDouble.set(true); + return DOUBLE.getDouble(block, position); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/BlockIndex.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/BlockIndex.java similarity index 94% rename from presto-main/src/main/java/com/facebook/presto/operator/aggregation/BlockIndex.java rename to presto-spi/src/main/java/com/facebook/presto/spi/function/BlockIndex.java index 065de6b9928bc..1b5d9bf87b806 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/BlockIndex.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/BlockIndex.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.operator.aggregation; +package com.facebook.presto.spi.function; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/BlockPosition.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/BlockPosition.java similarity index 94% rename from presto-main/src/main/java/com/facebook/presto/operator/aggregation/BlockPosition.java rename to presto-spi/src/main/java/com/facebook/presto/spi/function/BlockPosition.java index b4a62b5da6b96..cb060b241473e 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/BlockPosition.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/BlockPosition.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.operator.aggregation; +package com.facebook.presto.spi.function; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlType.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlType.java index 4dcdbd9014f09..17f67bb7dc452 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlType.java @@ -25,4 +25,6 @@ public @interface SqlType { String value() default ""; + + Class nativeContainerType() default Object.class; }