diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java index 0f142affac350..c0e582ea2def3 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java @@ -480,15 +480,25 @@ public FunctionRegistry(TypeManager typeManager, BlockEncodingSerde blockEncodin .scalars(HyperLogLogFunctions.class) .scalars(UnknownOperators.class) .scalars(BooleanOperators.class) + .scalar(BooleanOperators.BooleanIsDistinctFrom.class) .scalars(BigintOperators.class) + .scalar(BigintOperators.BigIntIsDistinctFrom.class) .scalars(IntegerOperators.class) + .scalar(IntegerOperators.IntegerIsDistinctFrom.class) .scalars(SmallintOperators.class) + .scalar(SmallintOperators.SmallIntIsDistinctFrom.class) .scalars(TinyintOperators.class) + .scalar(TinyintOperators.TinyIntIsDistinctFrom.class) .scalars(DoubleOperators.class) + .scalar(DoubleOperators.DoubleIsDistinctFrom.class) .scalars(RealOperators.class) + .scalar(RealOperators.RealIsDistinctFrom.class) .scalars(VarcharOperators.class) + .scalar(VarcharOperators.VarcharIsDistinctFrom.class) .scalars(VarbinaryOperators.class) + .scalar(VarbinaryOperators.VarbinaryIsDistinctFrom.class) .scalars(DateOperators.class) + .scalar(DateOperators.DateIsDistinctFrom.class) .scalars(TimeOperators.class) .scalars(TimestampOperators.class) .scalars(IntervalDayTimeOperators.class) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java index 8f68b94935602..3cf13c853b6f6 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java @@ -26,6 +26,8 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.OutputFunction; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java index c2fd5cb86d96b..269bdf8f3cfd2 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java @@ -18,8 +18,7 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.operator.ParametricImplementationsGroup; -import com.facebook.presto.operator.scalar.annotations.ScalarImplementation; -import com.facebook.presto.operator.scalar.annotations.ScalarImplementation.MethodHandleAndConstructor; +import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.type.TypeManager; import com.google.common.annotations.VisibleForTesting; @@ -28,7 +27,6 @@ import static com.facebook.presto.metadata.SignatureBinder.applyBoundVariables; import static com.facebook.presto.spi.StandardErrorCode.AMBIGUOUS_FUNCTION_IMPLEMENTATION; -import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING; import static com.facebook.presto.util.Failures.checkCondition; import static java.lang.String.format; @@ -38,12 +36,12 @@ public class ParametricScalar extends SqlScalarFunction { private final ScalarHeader details; - private final ParametricImplementationsGroup implementations; + private final ParametricImplementationsGroup implementations; public ParametricScalar( Signature signature, ScalarHeader details, - ParametricImplementationsGroup implementations) + ParametricImplementationsGroup implementations) { super(signature); this.details = requireNonNull(details); @@ -69,7 +67,7 @@ public String getDescription() } @VisibleForTesting - public ParametricImplementationsGroup getImplementations() + public ParametricImplementationsGroup getImplementations() { return implementations; } @@ -79,44 +77,27 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in { Signature boundSignature = applyBoundVariables(getSignature(), boundVariables, arity); if (implementations.getExactImplementations().containsKey(boundSignature)) { - ScalarImplementation implementation = implementations.getExactImplementations().get(boundSignature); - Optional methodHandleAndConstructor = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry); - checkCondition(methodHandleAndConstructor.isPresent(), FUNCTION_IMPLEMENTATION_ERROR, String.format("Exact implementation of %s do not match expected java types.", boundSignature.getName())); - return new ScalarFunctionImplementation( - implementation.isNullable(), - implementation.getArgumentProperties(), - methodHandleAndConstructor.get().getMethodHandle(), - methodHandleAndConstructor.get().getConstructor(), - isDeterministic()); + ParametricScalarImplementation implementation = implementations.getExactImplementations().get(boundSignature); + Optional scalarFunctionImplementation = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry, isDeterministic()); + return scalarFunctionImplementation.get(); } ScalarFunctionImplementation selectedImplementation = null; - for (ScalarImplementation implementation : implementations.getSpecializedImplementations()) { - Optional methodHandle = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry); - if (methodHandle.isPresent()) { + for (ParametricScalarImplementation implementation : implementations.getSpecializedImplementations()) { + Optional scalarFunctionImplementation = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry, isDeterministic()); + if (scalarFunctionImplementation.isPresent()) { checkCondition(selectedImplementation == null, AMBIGUOUS_FUNCTION_IMPLEMENTATION, "Ambiguous implementation for %s with bindings %s", getSignature(), boundVariables.getTypeVariables()); - selectedImplementation = new ScalarFunctionImplementation( - implementation.isNullable(), - implementation.getArgumentProperties(), - methodHandle.get().getMethodHandle(), - methodHandle.get().getConstructor(), - isDeterministic()); + selectedImplementation = scalarFunctionImplementation.get(); } } if (selectedImplementation != null) { return selectedImplementation; } - - for (ScalarImplementation implementation : implementations.getGenericImplementations()) { - Optional methodHandle = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry); - if (methodHandle.isPresent()) { + for (ParametricScalarImplementation implementation : implementations.getGenericImplementations()) { + Optional scalarFunctionImplementation = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry, isDeterministic()); + if (scalarFunctionImplementation.isPresent()) { checkCondition(selectedImplementation == null, AMBIGUOUS_FUNCTION_IMPLEMENTATION, "Ambiguous implementation for %s with bindings %s", getSignature(), boundVariables.getTypeVariables()); - selectedImplementation = new ScalarFunctionImplementation( - implementation.isNullable(), - implementation.getArgumentProperties(), - methodHandle.get().getMethodHandle(), - methodHandle.get().getConstructor(), - isDeterministic()); + selectedImplementation = scalarFunctionImplementation.get(); } } if (selectedImplementation != null) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java similarity index 59% rename from presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java rename to presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java index 9795884977125..539b6e2b77e9e 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java @@ -23,7 +23,10 @@ import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention; +import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.SqlNullable; import com.facebook.presto.spi.function.SqlType; @@ -44,9 +47,11 @@ import java.lang.reflect.Method; import java.lang.reflect.Parameter; import java.util.ArrayList; +import java.util.Collections; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.stream.Stream; @@ -67,6 +72,7 @@ import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentType.FUNCTION_TYPE; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentType.VALUE_TYPE; +import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.BLOCK_AND_POSITION; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.USE_BOXED_TYPE; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.USE_NULL_FLAG; @@ -81,75 +87,79 @@ import static java.lang.invoke.MethodHandles.permuteArguments; import static java.lang.reflect.Modifier.isStatic; import static java.util.Objects.requireNonNull; +import static org.testng.Assert.assertTrue; -public class ScalarImplementation +public class ParametricScalarImplementation implements ParametricImplementation { private final Signature signature; - private final boolean nullable; - private final List argumentProperties; - private final MethodHandle methodHandle; private final List dependencies; - private final Optional constructor; private final List constructorDependencies; private final List> argumentNativeContainerTypes; private final Map> specializedTypeParameters; + private final Class returnNativeContainerType; + private final SpecializedSignature specializedSignature; + private final List choices; - public ScalarImplementation( + private ParametricScalarImplementation( Signature signature, - boolean nullable, - List argumentProperties, - MethodHandle methodHandle, List dependencies, - Optional constructor, List constructorDependencies, List> argumentNativeContainerTypes, - Map> specializedTypeParameters) + Map> specializedTypeParameters, + List choices, + Class returnContainerType, + SpecializedSignature specializedSignature) { this.signature = requireNonNull(signature, "signature is null"); - this.nullable = nullable; - this.argumentProperties = ImmutableList.copyOf(requireNonNull(argumentProperties, "argumentProperties is null")); - this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); this.dependencies = ImmutableList.copyOf(requireNonNull(dependencies, "dependencies is null")); - this.constructor = requireNonNull(constructor, "constructor is null"); this.constructorDependencies = ImmutableList.copyOf(requireNonNull(constructorDependencies, "constructorDependencies is null")); this.argumentNativeContainerTypes = ImmutableList.copyOf(requireNonNull(argumentNativeContainerTypes, "argumentNativeContainerTypes is null")); this.specializedTypeParameters = ImmutableMap.copyOf(requireNonNull(specializedTypeParameters, "specializedTypeParameters is null")); + this.choices = choices; + this.returnNativeContainerType = returnContainerType; + this.specializedSignature = specializedSignature; } - public Optional specialize(Signature boundSignature, BoundVariables boundVariables, TypeManager typeManager, FunctionRegistry functionRegistry) + public Optional specialize(Signature boundSignature, BoundVariables boundVariables, TypeManager typeManager, FunctionRegistry functionRegistry, boolean isDeterministic) { + List implementationChoices = new ArrayList(); for (Map.Entry> entry : specializedTypeParameters.entrySet()) { if (!entry.getValue().isAssignableFrom(boundVariables.getTypeVariable(entry.getKey()).getJavaType())) { return Optional.empty(); } } - Class returnContainerType = getNullAwareReturnType(typeManager.getType(boundSignature.getReturnType()).getJavaType(), nullable); - if (!returnContainerType.equals(methodHandle.type().returnType())) { + + if (!returnNativeContainerType.equals(typeManager.getType(boundSignature.getReturnType()).getJavaType())) { return Optional.empty(); } - for (int i = 0; i < boundSignature.getArgumentTypes().size(); i++) { - ScalarFunctionImplementation.ArgumentProperty argumentProperty = argumentProperties.get(i); - if (boundSignature.getArgumentTypes().get(i).getBase().equals(FunctionType.NAME)) { - if (argumentProperty.getArgumentType() != FUNCTION_TYPE) { - return Optional.empty(); - } - } - else { - if (argumentProperty.getArgumentType() != VALUE_TYPE) { - return Optional.empty(); + for (ParametricScalarImplementationChoice choice : choices) { + for (int i = 0; i < boundSignature.getArgumentTypes().size(); i++) { + ScalarFunctionImplementation.ArgumentProperty argumentProperty = choice.getArgumentProperties().get(i); + if (boundSignature.getArgumentTypes().get(i).getBase().equals(FunctionType.NAME)) { + if (argumentProperty.getArgumentType() != FUNCTION_TYPE) { + return Optional.empty(); + } } + else { + if (argumentProperty.getArgumentType() != VALUE_TYPE) { + return Optional.empty(); + } - Class argumentType = typeManager.getType(boundSignature.getArgumentTypes().get(i)).getJavaType(); - Class argumentContainerType = getNullAwareContainerType(argumentType, argumentProperty.getNullConvention()); - if (!argumentNativeContainerTypes.get(i).isAssignableFrom(argumentContainerType)) { - return Optional.empty(); + Class argumentType = typeManager.getType(boundSignature.getArgumentTypes().get(i)).getJavaType(); + Class argumentContainerType = getNullAwareContainerType(argumentType, argumentProperty.getNullConvention()); + + if (!choice.getArgumentNativeContainerTypes().get(i).isAssignableFrom(argumentContainerType)) { + return Optional.empty(); + } } } + MethodHandle boundMethodHandle = bindDependencies(choice.methodHandle, dependencies, boundVariables, typeManager, functionRegistry); + Optional boundConstructor = choice.constructor.map(handle -> bindDependencies(handle, constructorDependencies, boundVariables, typeManager, functionRegistry)); + + implementationChoices.add(new ScalarImplementationChoice(choice.nullable, choice.argumentProperties, boundMethodHandle, boundConstructor)); } - MethodHandle boundMethodHandle = bindDependencies(this.methodHandle, dependencies, boundVariables, typeManager, functionRegistry); - Optional boundConstructor = this.constructor.map(handle -> bindDependencies(handle, constructorDependencies, boundVariables, typeManager, functionRegistry)); - return Optional.of(new MethodHandleAndConstructor(boundMethodHandle, boundConstructor)); + return Optional.of(new ScalarFunctionImplementation(implementationChoices, isDeterministic)); } private static Class getNullAwareReturnType(Class clazz, boolean nullable) @@ -183,30 +193,87 @@ public Signature getSignature() return signature; } - public boolean isNullable() + public List getDependencies() { - return nullable; + return dependencies; } - public List getArgumentProperties() + @VisibleForTesting + public List getConstructorDependencies() { - return argumentProperties; + return constructorDependencies; } - public MethodHandle getMethodHandle() + public void updateChoices(ParametricScalarImplementation implementation) { - return methodHandle; + for (ParametricScalarImplementationChoice choice : implementation.choices) { + this.choices.add(choice); + } + Collections.sort(choices, ParametricScalarImplementationChoice::compareTo); } - public List getDependencies() + public SpecializedSignature getSpecializedSignature() { - return dependencies; + return specializedSignature; } - @VisibleForTesting - public List getConstructorDependencies() + public static final class ParametricScalarImplementationChoice + implements Comparable { - return constructorDependencies; + private final boolean nullable; + private final List argumentProperties; + private final MethodHandle methodHandle; + private final Optional constructor; + private final Map> specializedTypeParameters; + private final List> argumentNativeContainerTypes; + private final int numberOfBlockPositionArguments; + + private ParametricScalarImplementationChoice( + boolean nullable, + List argumentProperties, + MethodHandle methodHandle, + Optional constructor, + Map> specializedTypeParameters, + List> argumentNativeContainerTypes, + int numberOfBlockPositionArguments) + { + this.nullable = nullable; + this.argumentProperties = argumentProperties; + this.methodHandle = methodHandle; + this.constructor = requireNonNull(constructor, "constructor is null"); + this.specializedTypeParameters = specializedTypeParameters; + this.argumentNativeContainerTypes = argumentNativeContainerTypes; + this.numberOfBlockPositionArguments = numberOfBlockPositionArguments; + } + + public boolean isNullable() + { + return nullable; + } + + public List getArgumentProperties() + { + return argumentProperties; + } + + public MethodHandle getMethodHandle() + { + return methodHandle; + } + + public List> getArgumentNativeContainerTypes() + { + return argumentNativeContainerTypes; + } + + @Override + public int compareTo(ParametricScalarImplementationChoice choice) + { + if (choice.numberOfBlockPositionArguments < this.numberOfBlockPositionArguments) { + return 1; + } + return -1; + } } public static final class MethodHandleAndConstructor @@ -231,6 +298,48 @@ public Optional getConstructor() } } + public static final class SpecializedSignature + { + private final Signature signature; + private final List> argumentNativeContainerTypes; + private final Map> specializedTypeParameters; + private final Class returnNativeContainerType; + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SpecializedSignature that = (SpecializedSignature) o; + return Objects.equals(signature, that.signature) && + Objects.equals(argumentNativeContainerTypes, that.argumentNativeContainerTypes) && + Objects.equals(specializedTypeParameters, that.specializedTypeParameters) && + Objects.equals(returnNativeContainerType, that.returnNativeContainerType); + } + + @Override + public int hashCode() + { + return Objects.hash(signature, argumentNativeContainerTypes, specializedTypeParameters, returnNativeContainerType); + } + + private SpecializedSignature( + Signature signature, + List> argumentNativeContainerTypes, + Map> specializedTypeParameters, + Class returnNativeContainerType) + { + this.signature = signature; + this.argumentNativeContainerTypes = argumentNativeContainerTypes; + this.specializedTypeParameters = specializedTypeParameters; + this.returnNativeContainerType = returnNativeContainerType; + } + } + public static final class Parser { private final String functionName; @@ -248,11 +357,17 @@ public static final class Parser private final Optional constructorMethodHandle; private final List constructorDependencies = new ArrayList<>(); private final List longVariableConstraints; + private final Class returnNativeContainerType; + private NullConventionFlag nullConventionFlag; + private int numberOfBlockPositionArguments; + + private final List choices = new ArrayList(); private Parser(String functionName, Method method, Optional> constructor) { this.functionName = requireNonNull(functionName, "functionName is null"); this.nullable = method.getAnnotation(SqlNullable.class) != null; + ParametricScalarImplementation.NullConventionFlag nullConventionFlag = nullable ? NullConventionFlag.NULLABLE_ARGUMENT : NullConventionFlag.NOT_NULLABLE_ARGUMENT; checkArgument(nullable || !containsLegacyNullable(method.getAnnotations()), "Method [%s] is annotated with @Nullable but not @SqlNullable", method); Stream.of(method.getAnnotationsByType(TypeParameter.class)) @@ -268,6 +383,8 @@ private Parser(String functionName, Method method, Optional> cons this.returnType = parseTypeSignature(returnType.value(), literalParameters); Class actualReturnType = method.getReturnType(); + this.returnNativeContainerType = Primitives.unwrap(actualReturnType); + if (Primitives.isWrapperType(actualReturnType)) { checkArgument(nullable, "Method [%s] has wrapper return type %s but is missing @SqlNullable", method, actualReturnType.getSimpleName()); } @@ -285,11 +402,15 @@ else if (actualReturnType.isPrimitive()) { "Expected type parameter to only contain A-Z and 0-9 (starting with A-Z), but got %s on method [%s]", typeParameter.value(), method); } + inferSpecialization(method, actualReturnType, returnType.value(), nullConventionFlag); parseArguments(method); this.constructorMethodHandle = getConstructor(method, constructor); this.methodHandle = getMethodHandle(method); + + ParametricScalarImplementationChoice choice = new ParametricScalarImplementationChoice(nullable, argumentProperties, methodHandle, constructorMethodHandle, specializedTypeParameters, argumentNativeContainerTypes, numberOfBlockPositionArguments); + choices.add(choice); } private void parseArguments(Method method) @@ -319,10 +440,22 @@ private void parseArguments(Method method) .findFirst() .orElseThrow(() -> new IllegalArgumentException(format("Method [%s] is missing @SqlType annotation for parameter", method))); TypeSignature typeSignature = parseTypeSignature(type.value(), literalParameters); - boolean nullableArgument = Stream.of(annotations).anyMatch(SqlNullable.class::isInstance); - checkArgument(nullableArgument || !containsLegacyNullable(annotations), "Method [%s] has parameter annotated with @Nullable but not @SqlNullable", method); + if (Stream.of(annotations).anyMatch(SqlNullable.class::isInstance)) { + nullConventionFlag = NullConventionFlag.NULLABLE_ARGUMENT; + } + else { + nullConventionFlag = NullConventionFlag.NOT_NULLABLE_ARGUMENT; + } + checkArgument(nullConventionFlag == NullConventionFlag.NULLABLE_ARGUMENT || !containsLegacyNullable(annotations), "Method [%s] has parameter annotated with @Nullable but not @SqlNullable", method); + + if (Stream.of(annotations).anyMatch(BlockPosition.class::isInstance)) { + nullConventionFlag = NullConventionFlag.IS_BLOCK_AND_POSITION_CONVENTION; + assertTrue(method.getParameterCount() > (i + 1)); + Annotation[] parameterAnnotations = method.getParameterAnnotations()[i + 1]; + assertTrue(Stream.of(parameterAnnotations).anyMatch(BlockIndex.class::isInstance)); + numberOfBlockPositionArguments++; + } - boolean hasNullFlag = false; if (method.getParameterCount() > (i + 1)) { Annotation[] parameterAnnotations = method.getParameterAnnotations()[i + 1]; if (Stream.of(parameterAnnotations).anyMatch(IsNull.class::isInstance)) { @@ -332,30 +465,30 @@ private void parseArguments(Method method) checkArgument(isNullType == boolean.class, "Method [%s] has non-boolean parameter with @IsNull", method); checkArgument((parameterType == Void.class) || !Primitives.isWrapperType(parameterType), "Method [%s] uses @IsNull following a parameter with boxed primitive type: %s", method, parameterType.getSimpleName()); - nullableArgument = true; - hasNullFlag = true; + nullConventionFlag = NullConventionFlag.HAS_NULL_FLAG; } } if (Primitives.isWrapperType(parameterType)) { - checkArgument(nullableArgument, "Method [%s] has parameter with wrapper type %s that is missing @SqlNullable", method, parameterType.getSimpleName()); + checkArgument(nullConventionFlag != NullConventionFlag.NOT_NULLABLE_ARGUMENT, "Method [%s] has parameter with wrapper type %s that is missing @SqlNullable", method, parameterType.getSimpleName()); } - else if (parameterType.isPrimitive() && !hasNullFlag) { - checkArgument(!nullableArgument, "Method [%s] has parameter with primitive type %s annotated with @SqlNullable", method, parameterType.getSimpleName()); + else if (parameterType.isPrimitive() && nullConventionFlag != NullConventionFlag.HAS_NULL_FLAG) { + checkArgument(nullConventionFlag == NullConventionFlag.NOT_NULLABLE_ARGUMENT, "Method [%s] has parameter with primitive type %s annotated with @SqlNullable", method, parameterType.getSimpleName()); } - if (typeParameterNames.contains(type.value()) && !(parameterType == Object.class && nullableArgument)) { - // Infer specialization on this type parameter. We don't do this for @SqlNullable Object because it could match a type like BIGINT - Class specialization = specializedTypeParameters.get(type.value()); - Class nativeParameterType = Primitives.unwrap(parameterType); - checkArgument(specialization == null || specialization.equals(nativeParameterType), "Method [%s] type %s has conflicting specializations %s and %s", method, type.value(), specialization, nativeParameterType); - specializedTypeParameters.put(type.value(), nativeParameterType); + inferSpecialization(method, parameterType, type.value(), nullConventionFlag); + + if (nullConventionFlag == NullConventionFlag.IS_BLOCK_AND_POSITION_CONVENTION) { + argumentNativeContainerTypes.add(type.nativeContainerType()); + } + else { + checkArgument(type.nativeContainerType().equals(Object.class), "Native container type for parameter needs to be an Object"); + argumentNativeContainerTypes.add(parameterType); } - argumentNativeContainerTypes.add(parameterType); argumentTypes.add(typeSignature); - if (hasNullFlag) { - // skip @IsNull parameter + if (nullConventionFlag == NullConventionFlag.HAS_NULL_FLAG || nullConventionFlag == NullConventionFlag.IS_BLOCK_AND_POSITION_CONVENTION) { + // skip @IsNull and @BlockIndex parameter i++; } @@ -365,11 +498,14 @@ else if (parameterType.isPrimitive() && !hasNullFlag) { } else { NullConvention nullConvention; - if (!nullableArgument) { + if (nullConventionFlag == NullConventionFlag.IS_BLOCK_AND_POSITION_CONVENTION) { + nullConvention = BLOCK_AND_POSITION; + } + else if (nullConventionFlag == NullConventionFlag.NOT_NULLABLE_ARGUMENT) { nullConvention = RETURN_NULL_ON_NULL; } else { - nullConvention = hasNullFlag ? USE_NULL_FLAG : USE_BOXED_TYPE; + nullConvention = nullConventionFlag == NullConventionFlag.HAS_NULL_FLAG ? USE_NULL_FLAG : USE_BOXED_TYPE; } argumentProperties.add(valueTypeArgumentProperty(nullConvention)); } @@ -377,6 +513,17 @@ else if (parameterType.isPrimitive() && !hasNullFlag) { } } + private void inferSpecialization(Method method, Class parameterType, String typeParameterName, NullConventionFlag nullableArgument) + { + if (typeParameterNames.contains(typeParameterName) && !(parameterType == Object.class && nullableArgument == NullConventionFlag.NULLABLE_ARGUMENT) && nullableArgument != NullConventionFlag.IS_BLOCK_AND_POSITION_CONVENTION) { + // Infer specialization on this type parameter. We don't do this for @SqlNullable Object because it could match a type like BIGINT + Class specialization = specializedTypeParameters.get(typeParameterName); + Class nativeParameterType = Primitives.unwrap(parameterType); + checkArgument(specialization == null || specialization.equals(nativeParameterType), "Method [%s] type %s has conflicting specializations %s and %s", method, typeParameterName, specialization, nativeParameterType); + specializedTypeParameters.put(typeParameterName, nativeParameterType); + } + } + // Find matching constructor, if this is an instance method, and populate constructorDependencies private Optional getConstructor(Method method, Optional> optionalConstructor) { @@ -427,7 +574,7 @@ private MethodHandle getMethodHandle(Method method) return methodHandle; } - public ScalarImplementation get() + public ParametricScalarImplementation get() { Signature signature = new Signature( functionName, @@ -437,21 +584,35 @@ public ScalarImplementation get() returnType, argumentTypes, false); - return new ScalarImplementation( + + SpecializedSignature specializedSignature = new SpecializedSignature( + signature, + argumentNativeContainerTypes, + specializedTypeParameters, + returnNativeContainerType); + + return new ParametricScalarImplementation( signature, - nullable, - argumentProperties, - methodHandle, dependencies, - constructorMethodHandle, constructorDependencies, argumentNativeContainerTypes, - specializedTypeParameters); + specializedTypeParameters, + choices, + returnNativeContainerType, + specializedSignature); } - public static ScalarImplementation parseImplementation(String functionName, Method method, Optional> constructor) + public static ParametricScalarImplementation parseImplementation(String functionName, Method method, Optional> constructor) { return new Parser(functionName, method, constructor).get(); } } + + public enum NullConventionFlag + { + NULLABLE_ARGUMENT, + HAS_NULL_FLAG, + NOT_NULLABLE_ARGUMENT, + IS_BLOCK_AND_POSITION_CONVENTION; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java index 92d9dbb9e06a5..1e6ea8b1e2ecc 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java @@ -18,6 +18,7 @@ import com.facebook.presto.operator.ParametricImplementationsGroup; import com.facebook.presto.operator.annotations.FunctionsParserHelper; import com.facebook.presto.operator.scalar.ParametricScalar; +import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation.SpecializedSignature; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.SqlType; @@ -26,7 +27,9 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Method; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; @@ -93,16 +96,24 @@ private static List findScalarsInFunctionSetClass(Class< private static SqlScalarFunction parseParametricScalar(ScalarHeaderAndMethods scalar, Optional> constructor) { - ParametricImplementationsGroup.Builder implementationsBuilder = ParametricImplementationsGroup.builder(); + ParametricImplementationsGroup.Builder implementationsBuilder = ParametricImplementationsGroup.builder(); ScalarImplementationHeader header = scalar.getHeader(); checkArgument(!header.getName().isEmpty()); + Map signatures = new HashMap(); for (Method method : scalar.getMethods()) { - ScalarImplementation implementation = ScalarImplementation.Parser.parseImplementation(header.getName(), method, constructor); - implementationsBuilder.addImplementation(implementation); + ParametricScalarImplementation implementation = ParametricScalarImplementation.Parser.parseImplementation(header.getName(), method, constructor); + if (!signatures.containsKey(implementation.getSpecializedSignature())) { + signatures.put(implementation.getSpecializedSignature(), implementation); + implementationsBuilder.addImplementation(implementation); + } + else { + ParametricScalarImplementation currentImplementation = signatures.get(implementation.getSpecializedSignature()); + currentImplementation.updateChoices(implementation); + } } - ParametricImplementationsGroup implementations = implementationsBuilder.build(); + ParametricImplementationsGroup implementations = implementationsBuilder.build(); Signature scalarSignature = implementations.getSignature(); header.getOperatorType().ifPresent(operatorType -> diff --git a/presto-main/src/main/java/com/facebook/presto/type/BigintOperators.java b/presto-main/src/main/java/com/facebook/presto/type/BigintOperators.java index 63132870586fc..a1a98140d164e 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/BigintOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/BigintOperators.java @@ -14,6 +14,9 @@ package com.facebook.presto.type; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; @@ -47,6 +50,7 @@ import static com.facebook.presto.spi.function.OperatorType.SATURATED_FLOOR_CAST; import static com.facebook.presto.spi.function.OperatorType.SUBTRACT; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Float.floatToRawIntBits; import static java.lang.Math.toIntExact; @@ -277,20 +281,39 @@ public static long hashCode(@SqlType(StandardTypes.BIGINT) long value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.BIGINT) long left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.BIGINT) long right, - @IsNull boolean rightNull) + public static class BigIntIsDistinctFrom { - if (leftNull != rightNull) { - return true; + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.BIGINT) long left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.BIGINT) long right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - if (leftNull) { - return false; + + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.BIGINT, nativeContainerType = long.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.BIGINT, nativeContainerType = long.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return !BIGINT.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } @ScalarOperator(XX_HASH_64) diff --git a/presto-main/src/main/java/com/facebook/presto/type/BooleanOperators.java b/presto-main/src/main/java/com/facebook/presto/type/BooleanOperators.java index 7f1dced74a8a0..cd1df1b271118 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/BooleanOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/BooleanOperators.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.type; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarFunction; @@ -33,6 +36,7 @@ import static com.facebook.presto.spi.function.OperatorType.LESS_THAN; import static com.facebook.presto.spi.function.OperatorType.LESS_THAN_OR_EQUAL; import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static java.lang.Float.floatToRawIntBits; import static java.nio.charset.StandardCharsets.US_ASCII; @@ -166,19 +170,38 @@ public static boolean not(@SqlType(StandardTypes.BOOLEAN) boolean value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.BOOLEAN) boolean left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.BOOLEAN) boolean right, - @IsNull boolean rightNull) - { - if (leftNull != rightNull) { - return true; + public static class BooleanIsDistinctFrom + { + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.BOOLEAN) boolean left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.BOOLEAN) boolean right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - if (leftNull) { - return false; + + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.BOOLEAN, nativeContainerType = boolean.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.BOOLEAN, nativeContainerType = boolean.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return !BOOLEAN.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } } diff --git a/presto-main/src/main/java/com/facebook/presto/type/DateOperators.java b/presto-main/src/main/java/com/facebook/presto/type/DateOperators.java index ba296079d5507..e5e46d5b78725 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/DateOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/DateOperators.java @@ -163,20 +163,41 @@ public static long hashCode(@SqlType(StandardTypes.DATE) long value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.DATE) long left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.DATE) long right, - @IsNull boolean rightNull) - { - if (leftNull != rightNull) { - return true; + public static class DateIsDistinctFrom + { + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.DATE) long left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.DATE) long right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - if (leftNull) { - return false; + + /* + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(StandardTypes.DATE) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(StandardTypes.DATE) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return DATE.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); + */ } @ScalarOperator(INDETERMINATE) diff --git a/presto-main/src/main/java/com/facebook/presto/type/DoubleOperators.java b/presto-main/src/main/java/com/facebook/presto/type/DoubleOperators.java index 3223ae8bf881c..4212d82e23226 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/DoubleOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/DoubleOperators.java @@ -15,6 +15,9 @@ import com.facebook.presto.operator.scalar.MathFunctions; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; @@ -49,6 +52,7 @@ import static com.facebook.presto.spi.function.OperatorType.SATURATED_FLOOR_CAST; import static com.facebook.presto.spi.function.OperatorType.SUBTRACT; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.google.common.base.Preconditions.checkState; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Double.doubleToLongBits; @@ -315,23 +319,45 @@ private static long saturatedFloorCastToLong(double value, long minValue, double } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.DOUBLE) double left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.DOUBLE) double right, - @IsNull boolean rightNull) + public static class DoubleIsDistinctFrom { - if (leftNull != rightNull) { - return true; - } - if (leftNull) { - return false; + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.DOUBLE) double left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.DOUBLE) double right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + if (Double.isNaN(left) && Double.isNaN(right)) { + return false; + } + return notEqual(left, right); } - if (Double.isNaN(left) && Double.isNaN(right)) { - return false; + + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.DOUBLE, nativeContainerType = double.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.DOUBLE, nativeContainerType = double.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + if (Double.isNaN(DOUBLE.getDouble(left, leftPosition)) && Double.isNaN(DOUBLE.getDouble(right, rightPosition))) { + return false; + } + return !DOUBLE.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } @ScalarOperator(XX_HASH_64) diff --git a/presto-main/src/main/java/com/facebook/presto/type/IntegerOperators.java b/presto-main/src/main/java/com/facebook/presto/type/IntegerOperators.java index da41e0a7c6232..4d50cedf03428 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/IntegerOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/IntegerOperators.java @@ -14,6 +14,9 @@ package com.facebook.presto.type; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; @@ -43,9 +46,9 @@ import static com.facebook.presto.spi.function.OperatorType.MULTIPLY; import static com.facebook.presto.spi.function.OperatorType.NEGATION; import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL; -import static com.facebook.presto.spi.function.OperatorType.SATURATED_FLOOR_CAST; import static com.facebook.presto.spi.function.OperatorType.SUBTRACT; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Float.floatToRawIntBits; import static java.lang.String.format; @@ -246,34 +249,39 @@ public static long hashCode(@SqlType(StandardTypes.INTEGER) long value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.INTEGER) long left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.INTEGER) long right, - @IsNull boolean rightNull) + public static class IntegerIsDistinctFrom { - if (leftNull != rightNull) { - return true; - } - if (leftNull) { - return false; + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.INTEGER) long left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.INTEGER) long right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - return notEqual(left, right); - } - - @ScalarOperator(SATURATED_FLOOR_CAST) - @SqlType(StandardTypes.SMALLINT) - public static long saturatedFloorCastToSmallint(@SqlType(StandardTypes.INTEGER) long value) - { - return Shorts.saturatedCast(value); - } - @ScalarOperator(SATURATED_FLOOR_CAST) - @SqlType(StandardTypes.TINYINT) - public static long saturatedFloorCastToTinyint(@SqlType(StandardTypes.INTEGER) long value) - { - return SignedBytes.saturatedCast(value); + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.INTEGER, nativeContainerType = long.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.INTEGER, nativeContainerType = long.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return !INTEGER.equalTo(left, leftPosition, right, rightPosition); + } } @ScalarOperator(INDETERMINATE) diff --git a/presto-main/src/main/java/com/facebook/presto/type/RealOperators.java b/presto-main/src/main/java/com/facebook/presto/type/RealOperators.java index dc05445801701..45aa31f23db9c 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/RealOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/RealOperators.java @@ -15,6 +15,9 @@ import com.facebook.presto.operator.scalar.MathFunctions; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; @@ -47,6 +50,7 @@ import static com.facebook.presto.spi.function.OperatorType.SATURATED_FLOOR_CAST; import static com.facebook.presto.spi.function.OperatorType.SUBTRACT; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.RealType.REAL; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Float.floatToIntBits; import static java.lang.Float.floatToRawIntBits; @@ -241,25 +245,44 @@ public static boolean castToBoolean(@SqlType(StandardTypes.REAL) long value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.REAL) long left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.REAL) long right, - @IsNull boolean rightNull) + public static class RealIsDistinctFrom { - if (leftNull != rightNull) { - return true; - } - if (leftNull) { - return false; + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.REAL) long left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.REAL) long right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + float leftFloat = intBitsToFloat((int) left); + float rightFloat = intBitsToFloat((int) right); + if (Float.isNaN(leftFloat) && Float.isNaN(rightFloat)) { + return false; + } + return notEqual(left, right); } - float leftFloat = intBitsToFloat((int) left); - float rightFloat = intBitsToFloat((int) right); - if (Float.isNaN(leftFloat) && Float.isNaN(rightFloat)) { - return false; + + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.REAL, nativeContainerType = long.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.REAL, nativeContainerType = long.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return REAL.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } @ScalarOperator(SATURATED_FLOOR_CAST) diff --git a/presto-main/src/main/java/com/facebook/presto/type/SmallintOperators.java b/presto-main/src/main/java/com/facebook/presto/type/SmallintOperators.java index 882653c20c5f6..83c0655b7b84f 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/SmallintOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/SmallintOperators.java @@ -14,6 +14,9 @@ package com.facebook.presto.type; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; @@ -46,6 +49,7 @@ import static com.facebook.presto.spi.function.OperatorType.SATURATED_FLOOR_CAST; import static com.facebook.presto.spi.function.OperatorType.SUBTRACT; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.SmallintType.SMALLINT; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Float.floatToRawIntBits; import static java.lang.String.format; @@ -241,20 +245,39 @@ public static long hashCode(@SqlType(StandardTypes.SMALLINT) long value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.SMALLINT) long left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.SMALLINT) long right, - @IsNull boolean rightNull) + public static class SmallIntIsDistinctFrom { - if (leftNull != rightNull) { - return true; + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.SMALLINT) long left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.SMALLINT) long right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - if (leftNull) { - return false; + + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.SMALLINT, nativeContainerType = long.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.SMALLINT, nativeContainerType = long.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return !SMALLINT.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } @ScalarOperator(SATURATED_FLOOR_CAST) diff --git a/presto-main/src/main/java/com/facebook/presto/type/TinyintOperators.java b/presto-main/src/main/java/com/facebook/presto/type/TinyintOperators.java index b2ac0d7cf082a..26f343bcc724d 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/TinyintOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/TinyintOperators.java @@ -14,6 +14,9 @@ package com.facebook.presto.type; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; @@ -44,6 +47,7 @@ import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL; import static com.facebook.presto.spi.function.OperatorType.SUBTRACT; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.TinyintType.TINYINT; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Float.floatToRawIntBits; import static java.lang.String.format; @@ -241,20 +245,39 @@ public static long xxHash64(@SqlType(StandardTypes.TINYINT) long value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.TINYINT) long left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.TINYINT) long right, - @IsNull boolean rightNull) + public static class TinyIntIsDistinctFrom { - if (leftNull != rightNull) { - return true; + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.TINYINT) long left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.TINYINT) long right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - if (leftNull) { - return false; + + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.TINYINT, nativeContainerType = long.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.TINYINT, nativeContainerType = long.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return !TINYINT.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } @ScalarOperator(INDETERMINATE) diff --git a/presto-main/src/main/java/com/facebook/presto/type/VarbinaryOperators.java b/presto-main/src/main/java/com/facebook/presto/type/VarbinaryOperators.java index a89631fb045b6..aa63e26e264d4 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/VarbinaryOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/VarbinaryOperators.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.type; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.SqlType; @@ -31,6 +34,7 @@ import static com.facebook.presto.spi.function.OperatorType.LESS_THAN_OR_EQUAL; import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; public final class VarbinaryOperators { @@ -98,20 +102,39 @@ public static long hashCode(@SqlType(StandardTypes.VARBINARY) Slice value) } @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType(StandardTypes.VARBINARY) Slice left, - @IsNull boolean leftNull, - @SqlType(StandardTypes.VARBINARY) Slice right, - @IsNull boolean rightNull) + public static class VarbinaryIsDistinctFrom { - if (leftNull != rightNull) { - return true; + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType(StandardTypes.VARBINARY) Slice left, + @IsNull boolean leftNull, + @SqlType(StandardTypes.VARBINARY) Slice right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - if (leftNull) { - return false; + + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = StandardTypes.VARBINARY, nativeContainerType = Slice.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = StandardTypes.VARBINARY, nativeContainerType = Slice.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return !VARBINARY.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } @ScalarOperator(XX_HASH_64) diff --git a/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java b/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java index e53eb6d12ba5f..e7536663475fa 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java @@ -14,6 +14,9 @@ package com.facebook.presto.type; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; @@ -35,6 +38,7 @@ import static com.facebook.presto.spi.function.OperatorType.LESS_THAN_OR_EQUAL; import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL; import static com.facebook.presto.spi.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static java.lang.String.format; public final class VarcharOperators @@ -235,22 +239,42 @@ public static long hashCode(@SqlType("varchar(x)") Slice value) return xxHash64(value); } - @LiteralParameters({"x", "y"}) @ScalarOperator(IS_DISTINCT_FROM) - @SqlType(StandardTypes.BOOLEAN) - public static boolean isDistinctFrom( - @SqlType("varchar(x)") Slice left, - @IsNull boolean leftNull, - @SqlType("varchar(y)") Slice right, - @IsNull boolean rightNull) + public static class VarcharIsDistinctFrom { - if (leftNull != rightNull) { - return true; + @LiteralParameters({"x", "y"}) + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @SqlType("varchar(x)") Slice left, + @IsNull boolean leftNull, + @SqlType("varchar(y)") Slice right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); } - if (leftNull) { - return false; + + @LiteralParameters({"x", "y"}) + @SqlType(StandardTypes.BOOLEAN) + public static boolean isDistinctFrom( + @BlockPosition @SqlType(value = "varchar(x)", nativeContainerType = Slice.class) Block left, + @BlockIndex int leftPosition, + @BlockPosition @SqlType(value = "varchar(y)", nativeContainerType = Slice.class) Block right, + @BlockIndex int rightPosition) + { + if (left.isNull(leftPosition) && right.isNull(rightPosition)) { + return false; + } + if (left.isNull(leftPosition)) { + return false; + } + return !VARCHAR.equalTo(left, leftPosition, right, rightPosition); } - return notEqual(left, right); } @LiteralParameters("x") diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java b/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java index 030ab612e6614..225e306d000c2 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java @@ -21,8 +21,6 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.AggregationImplementation; import com.facebook.presto.operator.aggregation.AggregationMetadata; -import com.facebook.presto.operator.aggregation.BlockIndex; -import com.facebook.presto.operator.aggregation.BlockPosition; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.operator.aggregation.LazyAccumulatorFactoryBinder; import com.facebook.presto.operator.aggregation.ParametricAggregation; @@ -39,6 +37,8 @@ import com.facebook.presto.spi.function.AggregationFunction; import com.facebook.presto.spi.function.AggregationState; import com.facebook.presto.spi.function.AggregationStateSerializerFactory; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.CombineFunction; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.InputFunction; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestCountNullAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestCountNullAggregation.java index 23947fdef2e7f..b710e282a1e90 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestCountNullAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestCountNullAggregation.java @@ -19,6 +19,8 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AggregationFunction; import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.CombineFunction; import com.facebook.presto.spi.function.InputFunction; import com.facebook.presto.spi.function.OutputFunction; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java index 2d395a18056af..42bf910202df8 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java @@ -101,6 +101,7 @@ import java.util.function.Supplier; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; import static com.facebook.presto.block.BlockAssertions.createBooleansBlock; import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; import static com.facebook.presto.block.BlockAssertions.createIntsBlock; @@ -117,6 +118,7 @@ import static com.facebook.presto.spi.type.DateTimeEncoding.packDateTimeWithZone; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.type.RealType.REAL; import static com.facebook.presto.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; @@ -162,7 +164,10 @@ public final class FunctionAssertions createStringsBlock((String) null), createTimestampsWithTimezoneBlock(packDateTimeWithZone(new DateTime(1970, 1, 1, 0, 1, 0, 999, DateTimeZone.UTC).getMillis(), TimeZoneKey.getTimeZoneKey("Z"))), createSlicesBlock(Slices.wrappedBuffer((byte) 0xab)), - createIntsBlock(1234)); + createIntsBlock(1234), + createIntsBlock((Integer) null), + createBlockOfReals(56.7f), + createDoublesBlock((Double) null)); private static final Page ZERO_CHANNEL_PAGE = new Page(1); @@ -177,6 +182,9 @@ public final class FunctionAssertions .put(7, TIMESTAMP_WITH_TIME_ZONE) .put(8, VARBINARY) .put(9, INTEGER) + .put(10, INTEGER) + .put(11, REAL) + .put(12, DOUBLE) .build(); private static final Map INPUT_MAPPING = ImmutableMap.builder() @@ -190,6 +198,9 @@ public final class FunctionAssertions .put(new Symbol("bound_timestamp_with_timezone"), 7) .put(new Symbol("bound_binary_literal"), 8) .put(new Symbol("bound_integer"), 9) + .put(new Symbol("bound_null_integer"), 10) + .put(new Symbol("bound_real"), 11) + .put(new Symbol("bound_null_double"), 12) .build(); private static final TypeProvider SYMBOL_TYPES = TypeProvider.copyOf(ImmutableMap.builder() @@ -203,6 +214,8 @@ public final class FunctionAssertions .put(new Symbol("bound_timestamp_with_timezone"), TIMESTAMP_WITH_TIME_ZONE) .put(new Symbol("bound_binary_literal"), VARBINARY) .put(new Symbol("bound_integer"), INTEGER) + .put(new Symbol("bound_null_integer"), INTEGER) + .put(new Symbol("bound_null_double"), DOUBLE) .build()); private static final PageSourceProvider PAGE_SOURCE_PROVIDER = new TestPageSourceProvider(); @@ -1014,7 +1027,7 @@ public ConnectorPageSource createPageSource(Session session, Split split, List nativeContainerType() default Object.class; }