From a091b4dcb5f1bc7bcb4e8d20ace9740ba9478c31 Mon Sep 17 00:00:00 2001 From: Gerlou Shyy Date: Thu, 19 Jul 2018 11:25:15 -0700 Subject: [PATCH 1/5] Move @BlockPosition and @BlockIndex to SPI --- .../operator/aggregation/AggregationImplementation.java | 2 ++ .../presto/operator/TestAnnotationEngineForAggregates.java | 4 ++-- .../presto/operator/aggregation/TestCountNullAggregation.java | 2 ++ .../java/com/facebook/presto/spi/function}/BlockIndex.java | 2 +- .../java/com/facebook/presto/spi/function}/BlockPosition.java | 2 +- 5 files changed, 8 insertions(+), 4 deletions(-) rename {presto-main/src/main/java/com/facebook/presto/operator/aggregation => presto-spi/src/main/java/com/facebook/presto/spi/function}/BlockIndex.java (94%) rename {presto-main/src/main/java/com/facebook/presto/operator/aggregation => presto-spi/src/main/java/com/facebook/presto/spi/function}/BlockPosition.java (94%) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java index 8f68b94935602..3cf13c853b6f6 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java @@ -26,6 +26,8 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.OutputFunction; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java b/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java index 030ab612e6614..3536a75b4fd08 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java @@ -21,8 +21,8 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.AggregationImplementation; import com.facebook.presto.operator.aggregation.AggregationMetadata; -import com.facebook.presto.operator.aggregation.BlockIndex; -import com.facebook.presto.operator.aggregation.BlockPosition; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.operator.aggregation.LazyAccumulatorFactoryBinder; import com.facebook.presto.operator.aggregation.ParametricAggregation; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestCountNullAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestCountNullAggregation.java index 23947fdef2e7f..b710e282a1e90 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestCountNullAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestCountNullAggregation.java @@ -19,6 +19,8 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AggregationFunction; import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.CombineFunction; import com.facebook.presto.spi.function.InputFunction; import com.facebook.presto.spi.function.OutputFunction; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/BlockIndex.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/BlockIndex.java similarity index 94% rename from presto-main/src/main/java/com/facebook/presto/operator/aggregation/BlockIndex.java rename to presto-spi/src/main/java/com/facebook/presto/spi/function/BlockIndex.java index 065de6b9928bc..1b5d9bf87b806 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/BlockIndex.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/BlockIndex.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.operator.aggregation; +package com.facebook.presto.spi.function; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/BlockPosition.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/BlockPosition.java similarity index 94% rename from presto-main/src/main/java/com/facebook/presto/operator/aggregation/BlockPosition.java rename to presto-spi/src/main/java/com/facebook/presto/spi/function/BlockPosition.java index b4a62b5da6b96..cb060b241473e 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/BlockPosition.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/BlockPosition.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.operator.aggregation; +package com.facebook.presto.spi.function; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; From 7877a00b367fc03f7f56304bc98c9aeb532353bd Mon Sep 17 00:00:00 2001 From: Gerlou Shyy Date: Thu, 19 Jul 2018 11:39:23 -0700 Subject: [PATCH 2/5] Rename ScalarImplementation to ParametricScalarImplementation --- .../presto/operator/scalar/ParametricScalar.java | 16 ++++++++-------- ....java => ParametricScalarImplementation.java} | 10 +++++----- .../annotations/ScalarFromAnnotationsParser.java | 6 +++--- ...arametricScalarImplementationValidation.java} | 12 ++++++------ 4 files changed, 22 insertions(+), 22 deletions(-) rename presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/{ScalarImplementation.java => ParametricScalarImplementation.java} (98%) rename presto-main/src/test/java/com/facebook/presto/operator/scalar/{TestScalarImplementationValidation.java => TestParametricScalarImplementationValidation.java} (84%) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java index c2fd5cb86d96b..04f250a66ca39 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java @@ -18,8 +18,8 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.operator.ParametricImplementationsGroup; -import com.facebook.presto.operator.scalar.annotations.ScalarImplementation; -import com.facebook.presto.operator.scalar.annotations.ScalarImplementation.MethodHandleAndConstructor; +import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation; +import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation.MethodHandleAndConstructor; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.type.TypeManager; import com.google.common.annotations.VisibleForTesting; @@ -38,12 +38,12 @@ public class ParametricScalar extends SqlScalarFunction { private final ScalarHeader details; - private final ParametricImplementationsGroup implementations; + private final ParametricImplementationsGroup implementations; public ParametricScalar( Signature signature, ScalarHeader details, - ParametricImplementationsGroup implementations) + ParametricImplementationsGroup implementations) { super(signature); this.details = requireNonNull(details); @@ -69,7 +69,7 @@ public String getDescription() } @VisibleForTesting - public ParametricImplementationsGroup getImplementations() + public ParametricImplementationsGroup getImplementations() { return implementations; } @@ -79,7 +79,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in { Signature boundSignature = applyBoundVariables(getSignature(), boundVariables, arity); if (implementations.getExactImplementations().containsKey(boundSignature)) { - ScalarImplementation implementation = implementations.getExactImplementations().get(boundSignature); + ParametricScalarImplementation implementation = implementations.getExactImplementations().get(boundSignature); Optional methodHandleAndConstructor = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry); checkCondition(methodHandleAndConstructor.isPresent(), FUNCTION_IMPLEMENTATION_ERROR, String.format("Exact implementation of %s do not match expected java types.", boundSignature.getName())); return new ScalarFunctionImplementation( @@ -91,7 +91,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in } ScalarFunctionImplementation selectedImplementation = null; - for (ScalarImplementation implementation : implementations.getSpecializedImplementations()) { + for (ParametricScalarImplementation implementation : implementations.getSpecializedImplementations()) { Optional methodHandle = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry); if (methodHandle.isPresent()) { checkCondition(selectedImplementation == null, AMBIGUOUS_FUNCTION_IMPLEMENTATION, "Ambiguous implementation for %s with bindings %s", getSignature(), boundVariables.getTypeVariables()); @@ -107,7 +107,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in return selectedImplementation; } - for (ScalarImplementation implementation : implementations.getGenericImplementations()) { + for (ParametricScalarImplementation implementation : implementations.getGenericImplementations()) { Optional methodHandle = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry); if (methodHandle.isPresent()) { checkCondition(selectedImplementation == null, AMBIGUOUS_FUNCTION_IMPLEMENTATION, "Ambiguous implementation for %s with bindings %s", getSignature(), boundVariables.getTypeVariables()); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java similarity index 98% rename from presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java rename to presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java index 9795884977125..d7f3a71c207b0 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java @@ -82,7 +82,7 @@ import static java.lang.reflect.Modifier.isStatic; import static java.util.Objects.requireNonNull; -public class ScalarImplementation +public class ParametricScalarImplementation implements ParametricImplementation { private final Signature signature; @@ -95,7 +95,7 @@ public class ScalarImplementation private final List> argumentNativeContainerTypes; private final Map> specializedTypeParameters; - public ScalarImplementation( + public ParametricScalarImplementation( Signature signature, boolean nullable, List argumentProperties, @@ -427,7 +427,7 @@ private MethodHandle getMethodHandle(Method method) return methodHandle; } - public ScalarImplementation get() + public ParametricScalarImplementation get() { Signature signature = new Signature( functionName, @@ -437,7 +437,7 @@ public ScalarImplementation get() returnType, argumentTypes, false); - return new ScalarImplementation( + return new ParametricScalarImplementation( signature, nullable, argumentProperties, @@ -449,7 +449,7 @@ public ScalarImplementation get() specializedTypeParameters); } - public static ScalarImplementation parseImplementation(String functionName, Method method, Optional> constructor) + public static ParametricScalarImplementation parseImplementation(String functionName, Method method, Optional> constructor) { return new Parser(functionName, method, constructor).get(); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java index 92d9dbb9e06a5..b4ec6331d266e 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java @@ -93,16 +93,16 @@ private static List findScalarsInFunctionSetClass(Class< private static SqlScalarFunction parseParametricScalar(ScalarHeaderAndMethods scalar, Optional> constructor) { - ParametricImplementationsGroup.Builder implementationsBuilder = ParametricImplementationsGroup.builder(); + ParametricImplementationsGroup.Builder implementationsBuilder = ParametricImplementationsGroup.builder(); ScalarImplementationHeader header = scalar.getHeader(); checkArgument(!header.getName().isEmpty()); for (Method method : scalar.getMethods()) { - ScalarImplementation implementation = ScalarImplementation.Parser.parseImplementation(header.getName(), method, constructor); + ParametricScalarImplementation implementation = ParametricScalarImplementation.Parser.parseImplementation(header.getName(), method, constructor); implementationsBuilder.addImplementation(implementation); } - ParametricImplementationsGroup implementations = implementationsBuilder.build(); + ParametricImplementationsGroup implementations = implementationsBuilder.build(); Signature scalarSignature = implementations.getSignature(); header.getOperatorType().ifPresent(operatorType -> diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestScalarImplementationValidation.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestParametricScalarImplementationValidation.java similarity index 84% rename from presto-main/src/test/java/com/facebook/presto/operator/scalar/TestScalarImplementationValidation.java rename to presto-main/src/test/java/com/facebook/presto/operator/scalar/TestParametricScalarImplementationValidation.java index b705eaa3c6648..1974a62c26adc 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestScalarImplementationValidation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestParametricScalarImplementationValidation.java @@ -26,15 +26,15 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.fail; -public class TestScalarImplementationValidation +public class TestParametricScalarImplementationValidation { - private static final MethodHandle STATE_FACTORY = methodHandle(TestScalarImplementationValidation.class, "createState"); + private static final MethodHandle STATE_FACTORY = methodHandle(TestParametricScalarImplementationValidation.class, "createState"); @Test public void testConnectorSessionPosition() { // Without cached instance factory - MethodHandle validFunctionMethodHandle = methodHandle(TestScalarImplementationValidation.class, "validConnectorSessionParameterPosition", ConnectorSession.class, long.class, long.class); + MethodHandle validFunctionMethodHandle = methodHandle(TestParametricScalarImplementationValidation.class, "validConnectorSessionParameterPosition", ConnectorSession.class, long.class, long.class); ScalarFunctionImplementation validFunction = new ScalarFunctionImplementation( false, ImmutableList.of( @@ -50,7 +50,7 @@ public void testConnectorSessionPosition() ImmutableList.of( valueTypeArgumentProperty(RETURN_NULL_ON_NULL), valueTypeArgumentProperty(RETURN_NULL_ON_NULL)), - methodHandle(TestScalarImplementationValidation.class, "invalidConnectorSessionParameterPosition", long.class, long.class, ConnectorSession.class), + methodHandle(TestParametricScalarImplementationValidation.class, "invalidConnectorSessionParameterPosition", long.class, long.class, ConnectorSession.class), false); fail("expected exception"); } @@ -59,7 +59,7 @@ public void testConnectorSessionPosition() } // With cached instance factory - MethodHandle validFunctionWithInstanceFactoryMethodHandle = methodHandle(TestScalarImplementationValidation.class, "validConnectorSessionParameterPosition", Object.class, ConnectorSession.class, long.class, long.class); + MethodHandle validFunctionWithInstanceFactoryMethodHandle = methodHandle(TestParametricScalarImplementationValidation.class, "validConnectorSessionParameterPosition", Object.class, ConnectorSession.class, long.class, long.class); ScalarFunctionImplementation validFunctionWithInstanceFactory = new ScalarFunctionImplementation( false, ImmutableList.of( @@ -76,7 +76,7 @@ public void testConnectorSessionPosition() ImmutableList.of( valueTypeArgumentProperty(RETURN_NULL_ON_NULL), valueTypeArgumentProperty(RETURN_NULL_ON_NULL)), - methodHandle(TestScalarImplementationValidation.class, "invalidConnectorSessionParameterPosition", Object.class, long.class, long.class, ConnectorSession.class), + methodHandle(TestParametricScalarImplementationValidation.class, "invalidConnectorSessionParameterPosition", Object.class, long.class, long.class, ConnectorSession.class), Optional.of(STATE_FACTORY), false); fail("expected exception"); From 83e67e64bd2e35e6e12905af7ac66233960750d7 Mon Sep 17 00:00:00 2001 From: Gerlou Shyy Date: Thu, 19 Jul 2018 12:01:10 -0700 Subject: [PATCH 3/5] Add type parameter specialization inference based on return type --- .../ParametricScalarImplementation.java | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java index d7f3a71c207b0..fdf9f15648a49 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java @@ -285,6 +285,7 @@ else if (actualReturnType.isPrimitive()) { "Expected type parameter to only contain A-Z and 0-9 (starting with A-Z), but got %s on method [%s]", typeParameter.value(), method); } + inferSpecialization(method, actualReturnType, returnType.value(), nullable); parseArguments(method); this.constructorMethodHandle = getConstructor(method, constructor); @@ -344,13 +345,7 @@ else if (parameterType.isPrimitive() && !hasNullFlag) { checkArgument(!nullableArgument, "Method [%s] has parameter with primitive type %s annotated with @SqlNullable", method, parameterType.getSimpleName()); } - if (typeParameterNames.contains(type.value()) && !(parameterType == Object.class && nullableArgument)) { - // Infer specialization on this type parameter. We don't do this for @SqlNullable Object because it could match a type like BIGINT - Class specialization = specializedTypeParameters.get(type.value()); - Class nativeParameterType = Primitives.unwrap(parameterType); - checkArgument(specialization == null || specialization.equals(nativeParameterType), "Method [%s] type %s has conflicting specializations %s and %s", method, type.value(), specialization, nativeParameterType); - specializedTypeParameters.put(type.value(), nativeParameterType); - } + inferSpecialization(method, parameterType, type.value(), nullableArgument); argumentNativeContainerTypes.add(parameterType); argumentTypes.add(typeSignature); @@ -377,6 +372,17 @@ else if (parameterType.isPrimitive() && !hasNullFlag) { } } + private void inferSpecialization(Method method, Class parameterType, String typeParameterName, boolean nullable) + { + if (typeParameterNames.contains(typeParameterName) && !(parameterType == Object.class && nullable)) { + // Infer specialization on this type parameter. We don't do this for @SqlNullable Object because it could match a type like BIGINT + Class specialization = specializedTypeParameters.get(typeParameterName); + Class nativeParameterType = Primitives.unwrap(parameterType); + checkArgument(specialization == null || specialization.equals(nativeParameterType), "Method [%s] type %s has conflicting specializations %s and %s", method, typeParameterName, specialization, nativeParameterType); + specializedTypeParameters.put(typeParameterName, nativeParameterType); + } + } + // Find matching constructor, if this is an instance method, and populate constructorDependencies private Optional getConstructor(Method method, Optional> optionalConstructor) { From ec298ada78d275b684667e75a60a8610057d85ca Mon Sep 17 00:00:00 2001 From: Gerlou Shyy Date: Thu, 19 Jul 2018 12:13:07 -0700 Subject: [PATCH 4/5] Add BLOCK_AND_POSITION calling convention to annotation framework --- .../operator/scalar/ParametricScalar.java | 35 +-- .../ParametricScalarImplementation.java | 289 ++++++++++++++---- .../ScalarFromAnnotationsParser.java | 13 +- .../TestBlockAndPositionNullConvention.java | 163 ++++++---- .../facebook/presto/spi/function/SqlType.java | 2 + 5 files changed, 352 insertions(+), 150 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java index 04f250a66ca39..269bdf8f3cfd2 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java @@ -19,7 +19,6 @@ import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.operator.ParametricImplementationsGroup; import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation; -import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation.MethodHandleAndConstructor; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.type.TypeManager; import com.google.common.annotations.VisibleForTesting; @@ -28,7 +27,6 @@ import static com.facebook.presto.metadata.SignatureBinder.applyBoundVariables; import static com.facebook.presto.spi.StandardErrorCode.AMBIGUOUS_FUNCTION_IMPLEMENTATION; -import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING; import static com.facebook.presto.util.Failures.checkCondition; import static java.lang.String.format; @@ -80,43 +78,26 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in Signature boundSignature = applyBoundVariables(getSignature(), boundVariables, arity); if (implementations.getExactImplementations().containsKey(boundSignature)) { ParametricScalarImplementation implementation = implementations.getExactImplementations().get(boundSignature); - Optional methodHandleAndConstructor = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry); - checkCondition(methodHandleAndConstructor.isPresent(), FUNCTION_IMPLEMENTATION_ERROR, String.format("Exact implementation of %s do not match expected java types.", boundSignature.getName())); - return new ScalarFunctionImplementation( - implementation.isNullable(), - implementation.getArgumentProperties(), - methodHandleAndConstructor.get().getMethodHandle(), - methodHandleAndConstructor.get().getConstructor(), - isDeterministic()); + Optional scalarFunctionImplementation = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry, isDeterministic()); + return scalarFunctionImplementation.get(); } ScalarFunctionImplementation selectedImplementation = null; for (ParametricScalarImplementation implementation : implementations.getSpecializedImplementations()) { - Optional methodHandle = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry); - if (methodHandle.isPresent()) { + Optional scalarFunctionImplementation = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry, isDeterministic()); + if (scalarFunctionImplementation.isPresent()) { checkCondition(selectedImplementation == null, AMBIGUOUS_FUNCTION_IMPLEMENTATION, "Ambiguous implementation for %s with bindings %s", getSignature(), boundVariables.getTypeVariables()); - selectedImplementation = new ScalarFunctionImplementation( - implementation.isNullable(), - implementation.getArgumentProperties(), - methodHandle.get().getMethodHandle(), - methodHandle.get().getConstructor(), - isDeterministic()); + selectedImplementation = scalarFunctionImplementation.get(); } } if (selectedImplementation != null) { return selectedImplementation; } - for (ParametricScalarImplementation implementation : implementations.getGenericImplementations()) { - Optional methodHandle = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry); - if (methodHandle.isPresent()) { + Optional scalarFunctionImplementation = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry, isDeterministic()); + if (scalarFunctionImplementation.isPresent()) { checkCondition(selectedImplementation == null, AMBIGUOUS_FUNCTION_IMPLEMENTATION, "Ambiguous implementation for %s with bindings %s", getSignature(), boundVariables.getTypeVariables()); - selectedImplementation = new ScalarFunctionImplementation( - implementation.isNullable(), - implementation.getArgumentProperties(), - methodHandle.get().getMethodHandle(), - methodHandle.get().getConstructor(), - isDeterministic()); + selectedImplementation = scalarFunctionImplementation.get(); } } if (selectedImplementation != null) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java index fdf9f15648a49..539b6e2b77e9e 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java @@ -23,7 +23,10 @@ import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention; +import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.SqlNullable; import com.facebook.presto.spi.function.SqlType; @@ -44,9 +47,11 @@ import java.lang.reflect.Method; import java.lang.reflect.Parameter; import java.util.ArrayList; +import java.util.Collections; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.stream.Stream; @@ -67,6 +72,7 @@ import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentType.FUNCTION_TYPE; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentType.VALUE_TYPE; +import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.BLOCK_AND_POSITION; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.USE_BOXED_TYPE; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.USE_NULL_FLAG; @@ -81,75 +87,79 @@ import static java.lang.invoke.MethodHandles.permuteArguments; import static java.lang.reflect.Modifier.isStatic; import static java.util.Objects.requireNonNull; +import static org.testng.Assert.assertTrue; public class ParametricScalarImplementation implements ParametricImplementation { private final Signature signature; - private final boolean nullable; - private final List argumentProperties; - private final MethodHandle methodHandle; private final List dependencies; - private final Optional constructor; private final List constructorDependencies; private final List> argumentNativeContainerTypes; private final Map> specializedTypeParameters; + private final Class returnNativeContainerType; + private final SpecializedSignature specializedSignature; + private final List choices; - public ParametricScalarImplementation( + private ParametricScalarImplementation( Signature signature, - boolean nullable, - List argumentProperties, - MethodHandle methodHandle, List dependencies, - Optional constructor, List constructorDependencies, List> argumentNativeContainerTypes, - Map> specializedTypeParameters) + Map> specializedTypeParameters, + List choices, + Class returnContainerType, + SpecializedSignature specializedSignature) { this.signature = requireNonNull(signature, "signature is null"); - this.nullable = nullable; - this.argumentProperties = ImmutableList.copyOf(requireNonNull(argumentProperties, "argumentProperties is null")); - this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); this.dependencies = ImmutableList.copyOf(requireNonNull(dependencies, "dependencies is null")); - this.constructor = requireNonNull(constructor, "constructor is null"); this.constructorDependencies = ImmutableList.copyOf(requireNonNull(constructorDependencies, "constructorDependencies is null")); this.argumentNativeContainerTypes = ImmutableList.copyOf(requireNonNull(argumentNativeContainerTypes, "argumentNativeContainerTypes is null")); this.specializedTypeParameters = ImmutableMap.copyOf(requireNonNull(specializedTypeParameters, "specializedTypeParameters is null")); + this.choices = choices; + this.returnNativeContainerType = returnContainerType; + this.specializedSignature = specializedSignature; } - public Optional specialize(Signature boundSignature, BoundVariables boundVariables, TypeManager typeManager, FunctionRegistry functionRegistry) + public Optional specialize(Signature boundSignature, BoundVariables boundVariables, TypeManager typeManager, FunctionRegistry functionRegistry, boolean isDeterministic) { + List implementationChoices = new ArrayList(); for (Map.Entry> entry : specializedTypeParameters.entrySet()) { if (!entry.getValue().isAssignableFrom(boundVariables.getTypeVariable(entry.getKey()).getJavaType())) { return Optional.empty(); } } - Class returnContainerType = getNullAwareReturnType(typeManager.getType(boundSignature.getReturnType()).getJavaType(), nullable); - if (!returnContainerType.equals(methodHandle.type().returnType())) { + + if (!returnNativeContainerType.equals(typeManager.getType(boundSignature.getReturnType()).getJavaType())) { return Optional.empty(); } - for (int i = 0; i < boundSignature.getArgumentTypes().size(); i++) { - ScalarFunctionImplementation.ArgumentProperty argumentProperty = argumentProperties.get(i); - if (boundSignature.getArgumentTypes().get(i).getBase().equals(FunctionType.NAME)) { - if (argumentProperty.getArgumentType() != FUNCTION_TYPE) { - return Optional.empty(); - } - } - else { - if (argumentProperty.getArgumentType() != VALUE_TYPE) { - return Optional.empty(); + for (ParametricScalarImplementationChoice choice : choices) { + for (int i = 0; i < boundSignature.getArgumentTypes().size(); i++) { + ScalarFunctionImplementation.ArgumentProperty argumentProperty = choice.getArgumentProperties().get(i); + if (boundSignature.getArgumentTypes().get(i).getBase().equals(FunctionType.NAME)) { + if (argumentProperty.getArgumentType() != FUNCTION_TYPE) { + return Optional.empty(); + } } + else { + if (argumentProperty.getArgumentType() != VALUE_TYPE) { + return Optional.empty(); + } - Class argumentType = typeManager.getType(boundSignature.getArgumentTypes().get(i)).getJavaType(); - Class argumentContainerType = getNullAwareContainerType(argumentType, argumentProperty.getNullConvention()); - if (!argumentNativeContainerTypes.get(i).isAssignableFrom(argumentContainerType)) { - return Optional.empty(); + Class argumentType = typeManager.getType(boundSignature.getArgumentTypes().get(i)).getJavaType(); + Class argumentContainerType = getNullAwareContainerType(argumentType, argumentProperty.getNullConvention()); + + if (!choice.getArgumentNativeContainerTypes().get(i).isAssignableFrom(argumentContainerType)) { + return Optional.empty(); + } } } + MethodHandle boundMethodHandle = bindDependencies(choice.methodHandle, dependencies, boundVariables, typeManager, functionRegistry); + Optional boundConstructor = choice.constructor.map(handle -> bindDependencies(handle, constructorDependencies, boundVariables, typeManager, functionRegistry)); + + implementationChoices.add(new ScalarImplementationChoice(choice.nullable, choice.argumentProperties, boundMethodHandle, boundConstructor)); } - MethodHandle boundMethodHandle = bindDependencies(this.methodHandle, dependencies, boundVariables, typeManager, functionRegistry); - Optional boundConstructor = this.constructor.map(handle -> bindDependencies(handle, constructorDependencies, boundVariables, typeManager, functionRegistry)); - return Optional.of(new MethodHandleAndConstructor(boundMethodHandle, boundConstructor)); + return Optional.of(new ScalarFunctionImplementation(implementationChoices, isDeterministic)); } private static Class getNullAwareReturnType(Class clazz, boolean nullable) @@ -183,30 +193,87 @@ public Signature getSignature() return signature; } - public boolean isNullable() + public List getDependencies() { - return nullable; + return dependencies; } - public List getArgumentProperties() + @VisibleForTesting + public List getConstructorDependencies() { - return argumentProperties; + return constructorDependencies; } - public MethodHandle getMethodHandle() + public void updateChoices(ParametricScalarImplementation implementation) { - return methodHandle; + for (ParametricScalarImplementationChoice choice : implementation.choices) { + this.choices.add(choice); + } + Collections.sort(choices, ParametricScalarImplementationChoice::compareTo); } - public List getDependencies() + public SpecializedSignature getSpecializedSignature() { - return dependencies; + return specializedSignature; } - @VisibleForTesting - public List getConstructorDependencies() + public static final class ParametricScalarImplementationChoice + implements Comparable { - return constructorDependencies; + private final boolean nullable; + private final List argumentProperties; + private final MethodHandle methodHandle; + private final Optional constructor; + private final Map> specializedTypeParameters; + private final List> argumentNativeContainerTypes; + private final int numberOfBlockPositionArguments; + + private ParametricScalarImplementationChoice( + boolean nullable, + List argumentProperties, + MethodHandle methodHandle, + Optional constructor, + Map> specializedTypeParameters, + List> argumentNativeContainerTypes, + int numberOfBlockPositionArguments) + { + this.nullable = nullable; + this.argumentProperties = argumentProperties; + this.methodHandle = methodHandle; + this.constructor = requireNonNull(constructor, "constructor is null"); + this.specializedTypeParameters = specializedTypeParameters; + this.argumentNativeContainerTypes = argumentNativeContainerTypes; + this.numberOfBlockPositionArguments = numberOfBlockPositionArguments; + } + + public boolean isNullable() + { + return nullable; + } + + public List getArgumentProperties() + { + return argumentProperties; + } + + public MethodHandle getMethodHandle() + { + return methodHandle; + } + + public List> getArgumentNativeContainerTypes() + { + return argumentNativeContainerTypes; + } + + @Override + public int compareTo(ParametricScalarImplementationChoice choice) + { + if (choice.numberOfBlockPositionArguments < this.numberOfBlockPositionArguments) { + return 1; + } + return -1; + } } public static final class MethodHandleAndConstructor @@ -231,6 +298,48 @@ public Optional getConstructor() } } + public static final class SpecializedSignature + { + private final Signature signature; + private final List> argumentNativeContainerTypes; + private final Map> specializedTypeParameters; + private final Class returnNativeContainerType; + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SpecializedSignature that = (SpecializedSignature) o; + return Objects.equals(signature, that.signature) && + Objects.equals(argumentNativeContainerTypes, that.argumentNativeContainerTypes) && + Objects.equals(specializedTypeParameters, that.specializedTypeParameters) && + Objects.equals(returnNativeContainerType, that.returnNativeContainerType); + } + + @Override + public int hashCode() + { + return Objects.hash(signature, argumentNativeContainerTypes, specializedTypeParameters, returnNativeContainerType); + } + + private SpecializedSignature( + Signature signature, + List> argumentNativeContainerTypes, + Map> specializedTypeParameters, + Class returnNativeContainerType) + { + this.signature = signature; + this.argumentNativeContainerTypes = argumentNativeContainerTypes; + this.specializedTypeParameters = specializedTypeParameters; + this.returnNativeContainerType = returnNativeContainerType; + } + } + public static final class Parser { private final String functionName; @@ -248,11 +357,17 @@ public static final class Parser private final Optional constructorMethodHandle; private final List constructorDependencies = new ArrayList<>(); private final List longVariableConstraints; + private final Class returnNativeContainerType; + private NullConventionFlag nullConventionFlag; + private int numberOfBlockPositionArguments; + + private final List choices = new ArrayList(); private Parser(String functionName, Method method, Optional> constructor) { this.functionName = requireNonNull(functionName, "functionName is null"); this.nullable = method.getAnnotation(SqlNullable.class) != null; + ParametricScalarImplementation.NullConventionFlag nullConventionFlag = nullable ? NullConventionFlag.NULLABLE_ARGUMENT : NullConventionFlag.NOT_NULLABLE_ARGUMENT; checkArgument(nullable || !containsLegacyNullable(method.getAnnotations()), "Method [%s] is annotated with @Nullable but not @SqlNullable", method); Stream.of(method.getAnnotationsByType(TypeParameter.class)) @@ -268,6 +383,8 @@ private Parser(String functionName, Method method, Optional> cons this.returnType = parseTypeSignature(returnType.value(), literalParameters); Class actualReturnType = method.getReturnType(); + this.returnNativeContainerType = Primitives.unwrap(actualReturnType); + if (Primitives.isWrapperType(actualReturnType)) { checkArgument(nullable, "Method [%s] has wrapper return type %s but is missing @SqlNullable", method, actualReturnType.getSimpleName()); } @@ -285,12 +402,15 @@ else if (actualReturnType.isPrimitive()) { "Expected type parameter to only contain A-Z and 0-9 (starting with A-Z), but got %s on method [%s]", typeParameter.value(), method); } - inferSpecialization(method, actualReturnType, returnType.value(), nullable); + inferSpecialization(method, actualReturnType, returnType.value(), nullConventionFlag); parseArguments(method); this.constructorMethodHandle = getConstructor(method, constructor); this.methodHandle = getMethodHandle(method); + + ParametricScalarImplementationChoice choice = new ParametricScalarImplementationChoice(nullable, argumentProperties, methodHandle, constructorMethodHandle, specializedTypeParameters, argumentNativeContainerTypes, numberOfBlockPositionArguments); + choices.add(choice); } private void parseArguments(Method method) @@ -320,10 +440,22 @@ private void parseArguments(Method method) .findFirst() .orElseThrow(() -> new IllegalArgumentException(format("Method [%s] is missing @SqlType annotation for parameter", method))); TypeSignature typeSignature = parseTypeSignature(type.value(), literalParameters); - boolean nullableArgument = Stream.of(annotations).anyMatch(SqlNullable.class::isInstance); - checkArgument(nullableArgument || !containsLegacyNullable(annotations), "Method [%s] has parameter annotated with @Nullable but not @SqlNullable", method); + if (Stream.of(annotations).anyMatch(SqlNullable.class::isInstance)) { + nullConventionFlag = NullConventionFlag.NULLABLE_ARGUMENT; + } + else { + nullConventionFlag = NullConventionFlag.NOT_NULLABLE_ARGUMENT; + } + checkArgument(nullConventionFlag == NullConventionFlag.NULLABLE_ARGUMENT || !containsLegacyNullable(annotations), "Method [%s] has parameter annotated with @Nullable but not @SqlNullable", method); + + if (Stream.of(annotations).anyMatch(BlockPosition.class::isInstance)) { + nullConventionFlag = NullConventionFlag.IS_BLOCK_AND_POSITION_CONVENTION; + assertTrue(method.getParameterCount() > (i + 1)); + Annotation[] parameterAnnotations = method.getParameterAnnotations()[i + 1]; + assertTrue(Stream.of(parameterAnnotations).anyMatch(BlockIndex.class::isInstance)); + numberOfBlockPositionArguments++; + } - boolean hasNullFlag = false; if (method.getParameterCount() > (i + 1)) { Annotation[] parameterAnnotations = method.getParameterAnnotations()[i + 1]; if (Stream.of(parameterAnnotations).anyMatch(IsNull.class::isInstance)) { @@ -333,24 +465,30 @@ private void parseArguments(Method method) checkArgument(isNullType == boolean.class, "Method [%s] has non-boolean parameter with @IsNull", method); checkArgument((parameterType == Void.class) || !Primitives.isWrapperType(parameterType), "Method [%s] uses @IsNull following a parameter with boxed primitive type: %s", method, parameterType.getSimpleName()); - nullableArgument = true; - hasNullFlag = true; + nullConventionFlag = NullConventionFlag.HAS_NULL_FLAG; } } if (Primitives.isWrapperType(parameterType)) { - checkArgument(nullableArgument, "Method [%s] has parameter with wrapper type %s that is missing @SqlNullable", method, parameterType.getSimpleName()); + checkArgument(nullConventionFlag != NullConventionFlag.NOT_NULLABLE_ARGUMENT, "Method [%s] has parameter with wrapper type %s that is missing @SqlNullable", method, parameterType.getSimpleName()); } - else if (parameterType.isPrimitive() && !hasNullFlag) { - checkArgument(!nullableArgument, "Method [%s] has parameter with primitive type %s annotated with @SqlNullable", method, parameterType.getSimpleName()); + else if (parameterType.isPrimitive() && nullConventionFlag != NullConventionFlag.HAS_NULL_FLAG) { + checkArgument(nullConventionFlag == NullConventionFlag.NOT_NULLABLE_ARGUMENT, "Method [%s] has parameter with primitive type %s annotated with @SqlNullable", method, parameterType.getSimpleName()); } - inferSpecialization(method, parameterType, type.value(), nullableArgument); - argumentNativeContainerTypes.add(parameterType); + inferSpecialization(method, parameterType, type.value(), nullConventionFlag); + + if (nullConventionFlag == NullConventionFlag.IS_BLOCK_AND_POSITION_CONVENTION) { + argumentNativeContainerTypes.add(type.nativeContainerType()); + } + else { + checkArgument(type.nativeContainerType().equals(Object.class), "Native container type for parameter needs to be an Object"); + argumentNativeContainerTypes.add(parameterType); + } argumentTypes.add(typeSignature); - if (hasNullFlag) { - // skip @IsNull parameter + if (nullConventionFlag == NullConventionFlag.HAS_NULL_FLAG || nullConventionFlag == NullConventionFlag.IS_BLOCK_AND_POSITION_CONVENTION) { + // skip @IsNull and @BlockIndex parameter i++; } @@ -360,11 +498,14 @@ else if (parameterType.isPrimitive() && !hasNullFlag) { } else { NullConvention nullConvention; - if (!nullableArgument) { + if (nullConventionFlag == NullConventionFlag.IS_BLOCK_AND_POSITION_CONVENTION) { + nullConvention = BLOCK_AND_POSITION; + } + else if (nullConventionFlag == NullConventionFlag.NOT_NULLABLE_ARGUMENT) { nullConvention = RETURN_NULL_ON_NULL; } else { - nullConvention = hasNullFlag ? USE_NULL_FLAG : USE_BOXED_TYPE; + nullConvention = nullConventionFlag == NullConventionFlag.HAS_NULL_FLAG ? USE_NULL_FLAG : USE_BOXED_TYPE; } argumentProperties.add(valueTypeArgumentProperty(nullConvention)); } @@ -372,9 +513,9 @@ else if (parameterType.isPrimitive() && !hasNullFlag) { } } - private void inferSpecialization(Method method, Class parameterType, String typeParameterName, boolean nullable) + private void inferSpecialization(Method method, Class parameterType, String typeParameterName, NullConventionFlag nullableArgument) { - if (typeParameterNames.contains(typeParameterName) && !(parameterType == Object.class && nullable)) { + if (typeParameterNames.contains(typeParameterName) && !(parameterType == Object.class && nullableArgument == NullConventionFlag.NULLABLE_ARGUMENT) && nullableArgument != NullConventionFlag.IS_BLOCK_AND_POSITION_CONVENTION) { // Infer specialization on this type parameter. We don't do this for @SqlNullable Object because it could match a type like BIGINT Class specialization = specializedTypeParameters.get(typeParameterName); Class nativeParameterType = Primitives.unwrap(parameterType); @@ -443,16 +584,22 @@ public ParametricScalarImplementation get() returnType, argumentTypes, false); + + SpecializedSignature specializedSignature = new SpecializedSignature( + signature, + argumentNativeContainerTypes, + specializedTypeParameters, + returnNativeContainerType); + return new ParametricScalarImplementation( signature, - nullable, - argumentProperties, - methodHandle, dependencies, - constructorMethodHandle, constructorDependencies, argumentNativeContainerTypes, - specializedTypeParameters); + specializedTypeParameters, + choices, + returnNativeContainerType, + specializedSignature); } public static ParametricScalarImplementation parseImplementation(String functionName, Method method, Optional> constructor) @@ -460,4 +607,12 @@ public static ParametricScalarImplementation parseImplementation(String function return new Parser(functionName, method, constructor).get(); } } + + public enum NullConventionFlag + { + NULLABLE_ARGUMENT, + HAS_NULL_FLAG, + NOT_NULLABLE_ARGUMENT, + IS_BLOCK_AND_POSITION_CONVENTION; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java index b4ec6331d266e..1e6ea8b1e2ecc 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java @@ -18,6 +18,7 @@ import com.facebook.presto.operator.ParametricImplementationsGroup; import com.facebook.presto.operator.annotations.FunctionsParserHelper; import com.facebook.presto.operator.scalar.ParametricScalar; +import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation.SpecializedSignature; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.SqlType; @@ -26,7 +27,9 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Method; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; @@ -97,9 +100,17 @@ private static SqlScalarFunction parseParametricScalar(ScalarHeaderAndMethods sc ScalarImplementationHeader header = scalar.getHeader(); checkArgument(!header.getName().isEmpty()); + Map signatures = new HashMap(); for (Method method : scalar.getMethods()) { ParametricScalarImplementation implementation = ParametricScalarImplementation.Parser.parseImplementation(header.getName(), method, constructor); - implementationsBuilder.addImplementation(implementation); + if (!signatures.containsKey(implementation.getSpecializedSignature())) { + signatures.put(implementation.getSpecializedSignature(), implementation); + implementationsBuilder.addImplementation(implementation); + } + else { + ParametricScalarImplementation currentImplementation = signatures.get(implementation.getSpecializedSignature()); + currentImplementation.updateChoices(implementation); + } } ParametricImplementationsGroup implementations = implementationsBuilder.build(); diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestBlockAndPositionNullConvention.java b/presto-main/src/test/java/com/facebook/presto/type/TestBlockAndPositionNullConvention.java index 236c646e59a2e..4e3fdf0723986 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestBlockAndPositionNullConvention.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestBlockAndPositionNullConvention.java @@ -13,30 +13,27 @@ */ package com.facebook.presto.type; -import com.facebook.presto.metadata.BoundVariables; -import com.facebook.presto.metadata.FunctionKind; -import com.facebook.presto.metadata.FunctionRegistry; -import com.facebook.presto.metadata.Signature; -import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.operator.scalar.AbstractTestFunctions; -import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; -import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice; import com.facebook.presto.spi.block.Block; -import com.facebook.presto.spi.type.TypeManager; -import com.google.common.collect.ImmutableList; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlNullable; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.Type; +import io.airlift.slice.Slice; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import java.lang.invoke.MethodHandle; -import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; -import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty; -import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.BLOCK_AND_POSITION; -import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL; import static com.facebook.presto.spi.type.BigintType.BIGINT; -import static com.facebook.presto.type.TestBlockAndPositionNullConvention.FunctionWithBlockAndPositionConvention.BLOCK_AND_POSITION_CONVENTION; -import static com.facebook.presto.util.Reflection.methodHandle; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; public class TestBlockAndPositionNullConvention @@ -45,76 +42,132 @@ public class TestBlockAndPositionNullConvention @BeforeClass public void setUp() { - registerScalarFunction(BLOCK_AND_POSITION_CONVENTION); + registerParametricScalar(FunctionWithBlockAndPositionConvention.class); } @Test public void testBlockPosition() { - assertFunction("identityFunction(9876543210)", BIGINT, 9876543210L); - assertFunction("identityFunction(bound_long)", BIGINT, 1234L); - assertTrue(FunctionWithBlockAndPositionConvention.hitBlockPosition.get()); + assertFunction("test_block_position(9876543210)", BIGINT, 9876543210L); + assertFalse(FunctionWithBlockAndPositionConvention.hitBlockPositionBigint.get()); + + assertFunction("test_block_position(bound_long)", BIGINT, 1234L); + assertTrue(FunctionWithBlockAndPositionConvention.hitBlockPositionBigint.get()); + + assertFunction("test_block_position(3.0E0)", DOUBLE, 3.0); + assertFalse(FunctionWithBlockAndPositionConvention.hitBlockPositionDouble.get()); + + assertFunction("test_block_position(bound_double)", DOUBLE, 12.34); + assertTrue(FunctionWithBlockAndPositionConvention.hitBlockPositionDouble.get()); + + assertFunction("test_block_position(bound_string)", VARCHAR, "hello"); + assertTrue(FunctionWithBlockAndPositionConvention.hitBlockPositionSlice.get()); + + // TODO: add adaptations so these will pass + //assertFunction("test_block_position(null)", UNKNOWN, null); + //assertFalse(FunctionWithBlockAndPositionConvention.hitBlockPositionObject.get()); + + assertFunction("test_block_position(false)", BOOLEAN, false); + assertFalse(FunctionWithBlockAndPositionConvention.hitBlockPositionBoolean.get()); + + assertFunction("test_block_position(bound_boolean)", BOOLEAN, true); + assertTrue(FunctionWithBlockAndPositionConvention.hitBlockPositionBoolean.get()); } + @ScalarFunction("test_block_position") public static class FunctionWithBlockAndPositionConvention - extends SqlScalarFunction { - private static final AtomicBoolean hitBlockPosition = new AtomicBoolean(); - public static final FunctionWithBlockAndPositionConvention BLOCK_AND_POSITION_CONVENTION = new FunctionWithBlockAndPositionConvention(); + private static final AtomicBoolean hitBlockPositionBigint = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionDouble = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionSlice = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionBoolean = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionObject = new AtomicBoolean(); - private static final MethodHandle METHOD_HANDLE_BLOCK_AND_POSITION = methodHandle(FunctionWithBlockAndPositionConvention.class, "getBlockPosition", Block.class, int.class); - private static final MethodHandle METHOD_HANDLE_NULL_ON_NULL = methodHandle(FunctionWithBlockAndPositionConvention.class, "getLong", long.class); + /* + // generic implementations + // these will not work right now because MethodHandle is not properly adapted - protected FunctionWithBlockAndPositionConvention() + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Object generic(@TypeParameter("E") Type type, @SqlNullable @SqlType("E") Object object) { - super(new Signature("identityFunction", FunctionKind.SCALAR, BIGINT.getTypeSignature(), BIGINT.getTypeSignature())); + return object; } - @Override - public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Object generic(@TypeParameter("E") Type type, @BlockPosition @SqlType("E") Block block, @BlockIndex int position) { - return new ScalarFunctionImplementation( - ImmutableList.of( - new ScalarImplementationChoice( - false, - ImmutableList.of(valueTypeArgumentProperty(RETURN_NULL_ON_NULL)), - METHOD_HANDLE_NULL_ON_NULL, - Optional.empty()), - new ScalarImplementationChoice( - false, - ImmutableList.of(valueTypeArgumentProperty(BLOCK_AND_POSITION)), - METHOD_HANDLE_BLOCK_AND_POSITION, - Optional.empty())), - isDeterministic()); + hitBlockPositionObject.set(true); + return TypeUtils.readNativeValue(type, block, position); } + */ + + // specialized - public static long getBlockPosition(Block block, int position) + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Slice specializedSlice(@TypeParameter("E") Type type, @SqlNullable @SqlType("E") Slice slice) { - hitBlockPosition.set(true); - return BIGINT.getLong(block, position); + return slice; } - public static long getLong(long number) + @TypeParameter("E") + @SqlType("E") + public static Slice specializedSlice(@TypeParameter("E") Type type, @BlockPosition @SqlType(value = "E", nativeContainerType = Slice.class) Block block, @BlockIndex int position) + { + hitBlockPositionSlice.set(true); + return type.getSlice(block, position); + } + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Boolean speciailizedBoolean(@TypeParameter("E") Type type, @SqlType("E") boolean bool) + { + return bool; + } + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Boolean speciailizedBoolean(@TypeParameter("E") Type type, @BlockPosition @SqlType(value = "E", nativeContainerType = boolean.class) Block block, @BlockIndex int position) + { + hitBlockPositionBoolean.set(true); + return type.getBoolean(block, position); + } + + // exact + + @SqlType(StandardTypes.BIGINT) + public static long getLong(@SqlType(StandardTypes.BIGINT) long number) { return number; } - @Override - public boolean isDeterministic() + @SqlType(StandardTypes.BIGINT) + public static long getBlockPosition(@BlockPosition @SqlType(value = StandardTypes.BIGINT, nativeContainerType = long.class) Block block, @BlockIndex int position) { - return true; + hitBlockPositionBigint.set(true); + return BIGINT.getLong(block, position); } - @Override - public boolean isHidden() + @SqlType(StandardTypes.DOUBLE) + @SqlNullable + public static Double getDouble(@SqlType(StandardTypes.DOUBLE) double number) { - return false; + return number; } - @Override - public String getDescription() + @SqlType(StandardTypes.DOUBLE) + @SqlNullable + public static Double getDouble(@BlockPosition @SqlType(value = StandardTypes.DOUBLE, nativeContainerType = double.class) Block block, @BlockIndex int position) { - return ""; + hitBlockPositionDouble.set(true); + return DOUBLE.getDouble(block, position); } } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlType.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlType.java index 4dcdbd9014f09..17f67bb7dc452 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlType.java @@ -25,4 +25,6 @@ public @interface SqlType { String value() default ""; + + Class nativeContainerType() default Object.class; } From 429b3f890e98f25e8acf6cffa65495adf4989831 Mon Sep 17 00:00:00 2001 From: Gerlou Shyy Date: Thu, 19 Jul 2018 12:14:00 -0700 Subject: [PATCH 5/5] Migrate IS_DISTINCT_FROM operator to use BLOCK_AND_POSITION convention --- .../presto/metadata/FunctionRegistry.java | 10 ++++ .../facebook/presto/type/BigintOperators.java | 45 ++++++++++---- .../presto/type/BooleanOperators.java | 47 +++++++++++---- .../facebook/presto/type/DateOperators.java | 45 ++++++++++---- .../facebook/presto/type/DoubleOperators.java | 54 ++++++++++++----- .../presto/type/IntegerOperators.java | 60 +++++++++++-------- .../facebook/presto/type/RealOperators.java | 55 ++++++++++++----- .../presto/type/SmallintOperators.java | 45 ++++++++++---- .../presto/type/TinyintOperators.java | 45 ++++++++++---- .../presto/type/VarbinaryOperators.java | 45 ++++++++++---- .../presto/type/VarcharOperators.java | 48 +++++++++++---- .../TestAnnotationEngineForAggregates.java | 4 +- .../operator/scalar/FunctionAssertions.java | 22 ++++++- .../presto/sql/TestExpressionInterpreter.java | 2 +- .../sql/gen/TestExpressionCompiler.java | 2 + .../presto/type/TestBigintOperators.java | 1 + .../presto/type/TestDoubleOperators.java | 2 + .../presto/type/TestIntegerOperators.java | 2 + .../presto/type/TestRealOperators.java | 1 + .../presto/type/TestVarcharOperators.java | 1 + 20 files changed, 394 insertions(+), 142 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java index 0f142affac350..c0e582ea2def3 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java @@ -480,15 +480,25 @@ public FunctionRegistry(TypeManager typeManager, BlockEncodingSerde blockEncodin .scalars(HyperLogLogFunctions.class) .scalars(UnknownOperators.class) .scalars(BooleanOperators.class) + .scalar(BooleanOperators.BooleanIsDistinctFrom.class) .scalars(BigintOperators.class) + .scalar(BigintOperators.BigIntIsDistinctFrom.class) .scalars(IntegerOperators.class) + .scalar(IntegerOperators.IntegerIsDistinctFrom.class) .scalars(SmallintOperators.class) + .scalar(SmallintOperators.SmallIntIsDistinctFrom.class) .scalars(TinyintOperators.class) + .scalar(TinyintOperators.TinyIntIsDistinctFrom.class) .scalars(DoubleOperators.class) + .scalar(DoubleOperators.DoubleIsDistinctFrom.class) .scalars(RealOperators.class) + .scalar(RealOperators.RealIsDistinctFrom.class) .scalars(VarcharOperators.class) + .scalar(VarcharOperators.VarcharIsDistinctFrom.class) .scalars(VarbinaryOperators.class) + .scalar(VarbinaryOperators.VarbinaryIsDistinctFrom.class) .scalars(DateOperators.class) + .scalar(DateOperators.DateIsDistinctFrom.class) .scalars(TimeOperators.class) .scalars(TimestampOperators.class) .scalars(IntervalDayTimeOperators.class) diff --git a/presto-main/src/main/java/com/facebook/presto/type/BigintOperators.java b/presto-main/src/main/java/com/facebook/presto/type/BigintOperators.java index 63132870586fc..a1a98140d164e 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/BigintOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/BigintOperators.java @@ -14,6 +14,9 @@ package com.facebook.presto.type; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; @@ -47,6 +50,7 @@ import static com.facebook.presto.spi.function.OperatorType.SATURATED_FLOOR_CAST; import static com.facebook.presto.spi.function.OperatorType.SUBTRACT; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Float.floatToRawIntBits; import static java.lang.Math.toIntExact; @@ -277,20 +281,39 @@ public static long hashCode(@SqlType(StandardTypes.BIGINT) long value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.BIGINT) long left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.BIGINT) long right, - @IsNull boolean rightNull) + public static class BigIntIsDistinctFrom { - if (leftNull != rightNull) { - return true; + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.BIGINT) long left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.BIGINT) long right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - if (leftNull) { - return false; + + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.BIGINT, nativeContainerType = long.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.BIGINT, nativeContainerType = long.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return !BIGINT.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } @ScalarOperator(XX_HASH_64) diff --git a/presto-main/src/main/java/com/facebook/presto/type/BooleanOperators.java b/presto-main/src/main/java/com/facebook/presto/type/BooleanOperators.java index 7f1dced74a8a0..cd1df1b271118 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/BooleanOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/BooleanOperators.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.type; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarFunction; @@ -33,6 +36,7 @@ import static com.facebook.presto.spi.function.OperatorType.LESS_THAN; import static com.facebook.presto.spi.function.OperatorType.LESS_THAN_OR_EQUAL; import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static java.lang.Float.floatToRawIntBits; import static java.nio.charset.StandardCharsets.US_ASCII; @@ -166,19 +170,38 @@ public static boolean not(@SqlType(StandardTypes.BOOLEAN) boolean value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.BOOLEAN) boolean left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.BOOLEAN) boolean right, - @IsNull boolean rightNull) - { - if (leftNull != rightNull) { - return true; + public static class BooleanIsDistinctFrom + { + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.BOOLEAN) boolean left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.BOOLEAN) boolean right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - if (leftNull) { - return false; + + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.BOOLEAN, nativeContainerType = boolean.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.BOOLEAN, nativeContainerType = boolean.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return !BOOLEAN.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } } diff --git a/presto-main/src/main/java/com/facebook/presto/type/DateOperators.java b/presto-main/src/main/java/com/facebook/presto/type/DateOperators.java index ba296079d5507..e5e46d5b78725 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/DateOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/DateOperators.java @@ -163,20 +163,41 @@ public static long hashCode(@SqlType(StandardTypes.DATE) long value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.DATE) long left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.DATE) long right, - @IsNull boolean rightNull) - { - if (leftNull != rightNull) { - return true; + public static class DateIsDistinctFrom + { + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.DATE) long left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.DATE) long right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - if (leftNull) { - return false; + + /* + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(StandardTypes.DATE) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(StandardTypes.DATE) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return DATE.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); + */ } @ScalarOperator(INDETERMINATE) diff --git a/presto-main/src/main/java/com/facebook/presto/type/DoubleOperators.java b/presto-main/src/main/java/com/facebook/presto/type/DoubleOperators.java index 3223ae8bf881c..4212d82e23226 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/DoubleOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/DoubleOperators.java @@ -15,6 +15,9 @@ import com.facebook.presto.operator.scalar.MathFunctions; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; @@ -49,6 +52,7 @@ import static com.facebook.presto.spi.function.OperatorType.SATURATED_FLOOR_CAST; import static com.facebook.presto.spi.function.OperatorType.SUBTRACT; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.google.common.base.Preconditions.checkState; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Double.doubleToLongBits; @@ -315,23 +319,45 @@ private static long saturatedFloorCastToLong(double value, long minValue, double } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.DOUBLE) double left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.DOUBLE) double right, - @IsNull boolean rightNull) + public static class DoubleIsDistinctFrom { - if (leftNull != rightNull) { - return true; - } - if (leftNull) { - return false; + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.DOUBLE) double left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.DOUBLE) double right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + if (Double.isNaN(left) && Double.isNaN(right)) { + return false; + } + return notEqual(left, right); } - if (Double.isNaN(left) && Double.isNaN(right)) { - return false; + + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.DOUBLE, nativeContainerType = double.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.DOUBLE, nativeContainerType = double.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + if (Double.isNaN(DOUBLE.getDouble(left, leftPosition)) && Double.isNaN(DOUBLE.getDouble(right, rightPosition))) { + return false; + } + return !DOUBLE.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } @ScalarOperator(XX_HASH_64) diff --git a/presto-main/src/main/java/com/facebook/presto/type/IntegerOperators.java b/presto-main/src/main/java/com/facebook/presto/type/IntegerOperators.java index da41e0a7c6232..4d50cedf03428 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/IntegerOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/IntegerOperators.java @@ -14,6 +14,9 @@ package com.facebook.presto.type; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; @@ -43,9 +46,9 @@ import static com.facebook.presto.spi.function.OperatorType.MULTIPLY; import static com.facebook.presto.spi.function.OperatorType.NEGATION; import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL; -import static com.facebook.presto.spi.function.OperatorType.SATURATED_FLOOR_CAST; import static com.facebook.presto.spi.function.OperatorType.SUBTRACT; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Float.floatToRawIntBits; import static java.lang.String.format; @@ -246,34 +249,39 @@ public static long hashCode(@SqlType(StandardTypes.INTEGER) long value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.INTEGER) long left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.INTEGER) long right, - @IsNull boolean rightNull) + public static class IntegerIsDistinctFrom { - if (leftNull != rightNull) { - return true; - } - if (leftNull) { - return false; + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.INTEGER) long left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.INTEGER) long right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - return notEqual(left, right); - } - - @ScalarOperator(SATURATED_FLOOR_CAST) - @SqlType(StandardTypes.SMALLINT) - public static long saturatedFloorCastToSmallint(@SqlType(StandardTypes.INTEGER) long value) - { - return Shorts.saturatedCast(value); - } - @ScalarOperator(SATURATED_FLOOR_CAST) - @SqlType(StandardTypes.TINYINT) - public static long saturatedFloorCastToTinyint(@SqlType(StandardTypes.INTEGER) long value) - { - return SignedBytes.saturatedCast(value); + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.INTEGER, nativeContainerType = long.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.INTEGER, nativeContainerType = long.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return !INTEGER.equalTo(left, leftPosition, right, rightPosition); + } } @ScalarOperator(INDETERMINATE) diff --git a/presto-main/src/main/java/com/facebook/presto/type/RealOperators.java b/presto-main/src/main/java/com/facebook/presto/type/RealOperators.java index dc05445801701..45aa31f23db9c 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/RealOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/RealOperators.java @@ -15,6 +15,9 @@ import com.facebook.presto.operator.scalar.MathFunctions; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; @@ -47,6 +50,7 @@ import static com.facebook.presto.spi.function.OperatorType.SATURATED_FLOOR_CAST; import static com.facebook.presto.spi.function.OperatorType.SUBTRACT; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.RealType.REAL; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Float.floatToIntBits; import static java.lang.Float.floatToRawIntBits; @@ -241,25 +245,44 @@ public static boolean castToBoolean(@SqlType(StandardTypes.REAL) long value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.REAL) long left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.REAL) long right, - @IsNull boolean rightNull) + public static class RealIsDistinctFrom { - if (leftNull != rightNull) { - return true; - } - if (leftNull) { - return false; + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.REAL) long left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.REAL) long right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + float leftFloat = intBitsToFloat((int) left); + float rightFloat = intBitsToFloat((int) right); + if (Float.isNaN(leftFloat) && Float.isNaN(rightFloat)) { + return false; + } + return notEqual(left, right); } - float leftFloat = intBitsToFloat((int) left); - float rightFloat = intBitsToFloat((int) right); - if (Float.isNaN(leftFloat) && Float.isNaN(rightFloat)) { - return false; + + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.REAL, nativeContainerType = long.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.REAL, nativeContainerType = long.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return REAL.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } @ScalarOperator(SATURATED_FLOOR_CAST) diff --git a/presto-main/src/main/java/com/facebook/presto/type/SmallintOperators.java b/presto-main/src/main/java/com/facebook/presto/type/SmallintOperators.java index 882653c20c5f6..83c0655b7b84f 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/SmallintOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/SmallintOperators.java @@ -14,6 +14,9 @@ package com.facebook.presto.type; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; @@ -46,6 +49,7 @@ import static com.facebook.presto.spi.function.OperatorType.SATURATED_FLOOR_CAST; import static com.facebook.presto.spi.function.OperatorType.SUBTRACT; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.SmallintType.SMALLINT; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Float.floatToRawIntBits; import static java.lang.String.format; @@ -241,20 +245,39 @@ public static long hashCode(@SqlType(StandardTypes.SMALLINT) long value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.SMALLINT) long left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.SMALLINT) long right, - @IsNull boolean rightNull) + public static class SmallIntIsDistinctFrom { - if (leftNull != rightNull) { - return true; + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.SMALLINT) long left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.SMALLINT) long right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - if (leftNull) { - return false; + + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.SMALLINT, nativeContainerType = long.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.SMALLINT, nativeContainerType = long.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return !SMALLINT.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } @ScalarOperator(SATURATED_FLOOR_CAST) diff --git a/presto-main/src/main/java/com/facebook/presto/type/TinyintOperators.java b/presto-main/src/main/java/com/facebook/presto/type/TinyintOperators.java index b2ac0d7cf082a..26f343bcc724d 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/TinyintOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/TinyintOperators.java @@ -14,6 +14,9 @@ package com.facebook.presto.type; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; @@ -44,6 +47,7 @@ import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL; import static com.facebook.presto.spi.function.OperatorType.SUBTRACT; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.TinyintType.TINYINT; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Float.floatToRawIntBits; import static java.lang.String.format; @@ -241,20 +245,39 @@ public static long xxHash64(@SqlType(StandardTypes.TINYINT) long value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.TINYINT) long left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.TINYINT) long right, - @IsNull boolean rightNull) + public static class TinyIntIsDistinctFrom { - if (leftNull != rightNull) { - return true; + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.TINYINT) long left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.TINYINT) long right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - if (leftNull) { - return false; + + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.TINYINT, nativeContainerType = long.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.TINYINT, nativeContainerType = long.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return !TINYINT.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } @ScalarOperator(INDETERMINATE) diff --git a/presto-main/src/main/java/com/facebook/presto/type/VarbinaryOperators.java b/presto-main/src/main/java/com/facebook/presto/type/VarbinaryOperators.java index a89631fb045b6..aa63e26e264d4 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/VarbinaryOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/VarbinaryOperators.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.type; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.SqlType; @@ -31,6 +34,7 @@ import static com.facebook.presto.spi.function.OperatorType.LESS_THAN_OR_EQUAL; import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; public final class VarbinaryOperators { @@ -98,20 +102,39 @@ public static long hashCode(@SqlType(StandardTypes.VARBINARY) Slice value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.VARBINARY) Slice left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.VARBINARY) Slice right, - @IsNull boolean rightNull) + public static class VarbinaryIsDistinctFrom { - if (leftNull != rightNull) { - return true; + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.VARBINARY) Slice left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.VARBINARY) Slice right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - if (leftNull) { - return false; + + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.VARBINARY, nativeContainerType = Slice.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.VARBINARY, nativeContainerType = Slice.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return !VARBINARY.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } @ScalarOperator(XX_HASH_64) diff --git a/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java b/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java index e53eb6d12ba5f..e7536663475fa 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java @@ -14,6 +14,9 @@ package com.facebook.presto.type; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; @@ -35,6 +38,7 @@ import static com.facebook.presto.spi.function.OperatorType.LESS_THAN_OR_EQUAL; import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static java.lang.String.format; public final class VarcharOperators @@ -235,22 +239,42 @@ public static long hashCode(@SqlType("varchar(x)") Slice value) return xxHash64(value); } - @LiteralParameters({"x", "y"}) @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType("varchar(x)") Slice left, - @IsNull boolean leftNull, - @SqlType("varchar(y)") Slice right, - @IsNull boolean rightNull) + public static class VarcharIsDistinctFrom { - if (leftNull != rightNull) { - return true; + @LiteralParameters({"x", "y"}) + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType("varchar(x)") Slice left, + @IsNull boolean leftNull, + @SqlType("varchar(y)") Slice right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - if (leftNull) { - return false; + + @LiteralParameters({"x", "y"}) + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = "varchar(x)", nativeContainerType = Slice.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = "varchar(y)", nativeContainerType = Slice.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return !VARCHAR.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } @LiteralParameters("x") diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java b/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java index 3536a75b4fd08..225e306d000c2 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java @@ -21,8 +21,6 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.AggregationImplementation; import com.facebook.presto.operator.aggregation.AggregationMetadata; -import com.facebook.presto.spi.function.BlockIndex; -import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.operator.aggregation.LazyAccumulatorFactoryBinder; import com.facebook.presto.operator.aggregation.ParametricAggregation; @@ -39,6 +37,8 @@ import com.facebook.presto.spi.function.AggregationFunction; import com.facebook.presto.spi.function.AggregationState; import com.facebook.presto.spi.function.AggregationStateSerializerFactory; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.CombineFunction; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.InputFunction; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java index 2d395a18056af..42bf910202df8 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java @@ -101,6 +101,7 @@ import java.util.function.Supplier; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; import static com.facebook.presto.block.BlockAssertions.createBooleansBlock; import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; import static com.facebook.presto.block.BlockAssertions.createIntsBlock; @@ -117,6 +118,7 @@ import static com.facebook.presto.spi.type.DateTimeEncoding.packDateTimeWithZone; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.type.RealType.REAL; import static com.facebook.presto.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; @@ -162,7 +164,10 @@ public final class FunctionAssertions createStringsBlock((String) null), createTimestampsWithTimezoneBlock(packDateTimeWithZone(new DateTime(1970, 1, 1, 0, 1, 0, 999, DateTimeZone.UTC).getMillis(), TimeZoneKey.getTimeZoneKey("Z"))), createSlicesBlock(Slices.wrappedBuffer((byte) 0xab)), - createIntsBlock(1234)); + createIntsBlock(1234), + createIntsBlock((Integer) null), + createBlockOfReals(56.7f), + createDoublesBlock((Double) null)); private static final Page ZERO_CHANNEL_PAGE = new Page(1); @@ -177,6 +182,9 @@ public final class FunctionAssertions .put(7, TIMESTAMP_WITH_TIME_ZONE) .put(8, VARBINARY) .put(9, INTEGER) + .put(10, INTEGER) + .put(11, REAL) + .put(12, DOUBLE) .build(); private static final Map INPUT_MAPPING = ImmutableMap.builder() @@ -190,6 +198,9 @@ public final class FunctionAssertions .put(new Symbol("bound_timestamp_with_timezone"), 7) .put(new Symbol("bound_binary_literal"), 8) .put(new Symbol("bound_integer"), 9) + .put(new Symbol("bound_null_integer"), 10) + .put(new Symbol("bound_real"), 11) + .put(new Symbol("bound_null_double"), 12) .build(); private static final TypeProvider SYMBOL_TYPES = TypeProvider.copyOf(ImmutableMap.builder() @@ -203,6 +214,8 @@ public final class FunctionAssertions .put(new Symbol("bound_timestamp_with_timezone"), TIMESTAMP_WITH_TIME_ZONE) .put(new Symbol("bound_binary_literal"), VARBINARY) .put(new Symbol("bound_integer"), INTEGER) + .put(new Symbol("bound_null_integer"), INTEGER) + .put(new Symbol("bound_null_double"), DOUBLE) .build()); private static final PageSourceProvider PAGE_SOURCE_PROVIDER = new TestPageSourceProvider(); @@ -1014,7 +1027,7 @@ public ConnectorPageSource createPageSource(Session session, Split split, List