diff --git a/presto-main/src/main/java/io/prestosql/metadata/PolymorphicScalarFunction.java b/presto-main/src/main/java/io/prestosql/metadata/PolymorphicScalarFunction.java index 63e3e413b2b5..146221d8b438 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/PolymorphicScalarFunction.java +++ b/presto-main/src/main/java/io/prestosql/metadata/PolymorphicScalarFunction.java @@ -21,6 +21,7 @@ import io.prestosql.operator.scalar.ScalarFunctionImplementation; import io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentProperty; import io.prestosql.operator.scalar.ScalarFunctionImplementation.NullConvention; +import io.prestosql.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention; import io.prestosql.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice; import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeManager; @@ -124,7 +125,7 @@ private ScalarImplementationChoice getScalarFunctionImplementationChoice( List extraParameters = computeExtraParameters(matchingMethodsGroup.get(), context); MethodHandle methodHandle = applyExtraParameters(matchingMethod.get().getMethod(), extraParameters, choice.getArgumentProperties()); - return new ScalarImplementationChoice(choice.isNullableResult(), choice.getArgumentProperties(), methodHandle, Optional.empty()); + return new ScalarImplementationChoice(choice.isNullableResult(), choice.getArgumentProperties(), choice.getReturnPlaceConvention(), methodHandle, Optional.empty()); } private static boolean matchesParameterAndReturnTypes( @@ -218,15 +219,18 @@ static final class PolymorphicScalarFunctionChoice { private final boolean nullableResult; private final List argumentProperties; + private final ReturnPlaceConvention returnPlaceConvention; private final List methodsGroups; PolymorphicScalarFunctionChoice( boolean nullableResult, List argumentProperties, + ReturnPlaceConvention returnPlaceConvention, List methodsGroups) { this.nullableResult = nullableResult; this.argumentProperties = ImmutableList.copyOf(requireNonNull(argumentProperties, "argumentProperties is null")); + this.returnPlaceConvention = requireNonNull(returnPlaceConvention, "returnPlaceConvention is null"); this.methodsGroups = ImmutableList.copyOf(requireNonNull(methodsGroups, "methodsWithExtraParametersFunctions is null")); } @@ -244,5 +248,10 @@ List getArgumentProperties() { return argumentProperties; } + + ReturnPlaceConvention getReturnPlaceConvention() + { + return returnPlaceConvention; + } } } diff --git a/presto-main/src/main/java/io/prestosql/metadata/PolymorphicScalarFunctionBuilder.java b/presto-main/src/main/java/io/prestosql/metadata/PolymorphicScalarFunctionBuilder.java index 3ea52a930c57..d12f21eb2038 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/PolymorphicScalarFunctionBuilder.java +++ b/presto-main/src/main/java/io/prestosql/metadata/PolymorphicScalarFunctionBuilder.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import io.prestosql.metadata.PolymorphicScalarFunction.PolymorphicScalarFunctionChoice; import io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentProperty; +import io.prestosql.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention; import io.prestosql.spi.function.OperatorType; import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeManager; @@ -250,6 +251,7 @@ public static class ChoiceBuilder private final Signature signature; private boolean nullableResult; private List argumentProperties; + private ReturnPlaceConvention returnPlaceConvention; private final ImmutableList.Builder methodsGroups = ImmutableList.builder(); private ChoiceBuilder(Class clazz, Signature signature) @@ -264,6 +266,10 @@ public ChoiceBuilder implementation(Function instanceFactory, boolean deterministic) { - this(ImmutableList.of(new ScalarImplementationChoice(nullable, argumentProperties, methodHandle, instanceFactory)), deterministic); + this( + ImmutableList.of(new ScalarImplementationChoice( + nullable, + argumentProperties, + ReturnPlaceConvention.STACK, + methodHandle, + instanceFactory)), + deterministic); } /** @@ -107,6 +114,7 @@ public static class ScalarImplementationChoice { private final boolean nullable; private final List argumentProperties; + private final ReturnPlaceConvention returnPlaceConvention; private final MethodHandle methodHandle; private final Optional instanceFactory; private final boolean hasSession; @@ -114,11 +122,13 @@ public static class ScalarImplementationChoice public ScalarImplementationChoice( boolean nullable, List argumentProperties, + ReturnPlaceConvention returnPlaceConvention, MethodHandle methodHandle, Optional instanceFactory) { this.nullable = nullable; this.argumentProperties = ImmutableList.copyOf(requireNonNull(argumentProperties, "argumentProperties is null")); + this.returnPlaceConvention = requireNonNull(returnPlaceConvention, "returnPlaceConvention is null"); this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); this.instanceFactory = requireNonNull(instanceFactory, "instanceFactory is null"); @@ -158,6 +168,11 @@ public ArgumentProperty getArgumentProperty(int argumentIndex) return argumentProperties.get(argumentIndex); } + public ReturnPlaceConvention getReturnPlaceConvention() + { + return returnPlaceConvention; + } + public MethodHandle getMethodHandle() { return methodHandle; @@ -279,4 +294,10 @@ public enum ArgumentType VALUE_TYPE, FUNCTION_TYPE } + + public enum ReturnPlaceConvention + { + STACK, + PROVIDED_BLOCKBUILDER + } } diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/annotations/ParametricScalarImplementation.java b/presto-main/src/main/java/io/prestosql/operator/scalar/annotations/ParametricScalarImplementation.java index e189673ac51d..ae477a8b7d80 100644 --- a/presto-main/src/main/java/io/prestosql/operator/scalar/annotations/ParametricScalarImplementation.java +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/annotations/ParametricScalarImplementation.java @@ -28,6 +28,7 @@ import io.prestosql.operator.scalar.ScalarFunctionImplementation; import io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentProperty; import io.prestosql.operator.scalar.ScalarFunctionImplementation.NullConvention; +import io.prestosql.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention; import io.prestosql.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice; import io.prestosql.spi.block.Block; import io.prestosql.spi.connector.ConnectorSession; @@ -164,6 +165,7 @@ public Optional specialize(Signature boundSignatur implementationChoices.add(new ScalarImplementationChoice( choice.nullable, choice.argumentProperties, + choice.returnPlaceConvention, boundMethodHandle.asType(javaMethodType(choice, boundSignature, typeManager)), boundConstructor)); } @@ -305,6 +307,7 @@ public static final class ParametricScalarImplementationChoice { private final boolean nullable; private final List argumentProperties; + private final ReturnPlaceConvention returnPlaceConvention; private final MethodHandle methodHandle; private final Optional constructor; private final List dependencies; @@ -316,6 +319,7 @@ private ParametricScalarImplementationChoice( boolean nullable, boolean hasConnectorSession, List argumentProperties, + ReturnPlaceConvention returnPlaceConvention, MethodHandle methodHandle, Optional constructor, List dependencies, @@ -324,6 +328,7 @@ private ParametricScalarImplementationChoice( this.nullable = nullable; this.hasConnectorSession = hasConnectorSession; this.argumentProperties = ImmutableList.copyOf(requireNonNull(argumentProperties, "argumentProperties is null")); + this.returnPlaceConvention = requireNonNull(returnPlaceConvention, "returnPlaceConvention is null"); this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); this.constructor = requireNonNull(constructor, "constructor is null"); this.dependencies = ImmutableList.copyOf(requireNonNull(dependencies, "dependencies is null")); @@ -364,6 +369,11 @@ public List getArgumentProperties() return argumentProperties; } + public ReturnPlaceConvention getReturnPlaceConvention() + { + return returnPlaceConvention; + } + public boolean checkDependencies() { for (int i = 1; i < getDependencies().size(); i++) { @@ -503,7 +513,15 @@ else if (actualReturnType.isPrimitive()) { this.methodHandle = getMethodHandle(method); - ParametricScalarImplementationChoice choice = new ParametricScalarImplementationChoice(nullable, hasConnectorSession, argumentProperties, methodHandle, constructorMethodHandle, dependencies, constructorDependencies); + ParametricScalarImplementationChoice choice = new ParametricScalarImplementationChoice( + nullable, + hasConnectorSession, + argumentProperties, + ReturnPlaceConvention.STACK, // TODO: support other return place convention + methodHandle, + constructorMethodHandle, + dependencies, + constructorDependencies); choices.add(choice); } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/AndCodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/AndCodeGenerator.java index 835a2ff53d83..7d71f5dd8c79 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/AndCodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/AndCodeGenerator.java @@ -24,14 +24,16 @@ import io.prestosql.sql.relational.RowExpression; import java.util.List; +import java.util.Optional; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; +import static io.prestosql.sql.gen.BytecodeGenerator.generateWrite; public class AndCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generator, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generator, Type returnType, List arguments, Optional outputBlockVariable) { Preconditions.checkArgument(arguments.size() == 2); @@ -40,8 +42,8 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon .comment("AND") .setDescription("AND"); - BytecodeNode left = generator.generate(arguments.get(0)); - BytecodeNode right = generator.generate(arguments.get(1)); + BytecodeNode left = generator.generate(arguments.get(0), Optional.empty()); + BytecodeNode right = generator.generate(arguments.get(1), Optional.empty()); block.append(left); @@ -97,6 +99,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon block.append(ifRightIsNull) .visitLabel(end); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generator, returnType, output))); return block; } } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/BindCodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/BindCodeGenerator.java index 5e2ba48cebda..208899da52c8 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/BindCodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/BindCodeGenerator.java @@ -15,6 +15,7 @@ package io.prestosql.sql.gen; import io.airlift.bytecode.BytecodeNode; +import io.airlift.bytecode.Variable; import io.prestosql.metadata.Signature; import io.prestosql.spi.type.Type; import io.prestosql.sql.gen.LambdaBytecodeGenerator.CompiledLambda; @@ -23,7 +24,9 @@ import java.util.List; import java.util.Map; +import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; public class BindCodeGenerator @@ -39,12 +42,18 @@ public BindCodeGenerator(Map compile } @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments, Optional outputBlockVariable) { // Bind expression is used to generate captured lambda. // It takes the captured values and the uncaptured lambda, and produces captured lambda as the output. // The uncaptured lambda is just a method, and does not have a stack representation during execution. // As a result, the bind expression generates the captured lambda in one step. + + // outputBlockVariable cannot present because + // 1. bind cannot be in the top level of an expression + // 2. lambda cannot be put into blocks. + checkArgument(!outputBlockVariable.isPresent()); + int numCaptures = arguments.size() - 1; LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) arguments.get(numCaptures); checkState(compiledLambdaMap.containsKey(lambda), "lambda expressions map does not contain this lambda definition"); diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeGenerator.java index 7fd20af19c4f..5f0472a28bdc 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeGenerator.java @@ -14,13 +14,25 @@ package io.prestosql.sql.gen; import io.airlift.bytecode.BytecodeNode; +import io.airlift.bytecode.Variable; import io.prestosql.metadata.Signature; import io.prestosql.spi.type.Type; import io.prestosql.sql.relational.RowExpression; import java.util.List; +import java.util.Optional; public interface BytecodeGenerator { - BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments); + BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments, Optional outputBlockVariable); + + static BytecodeNode generateWrite(BytecodeGeneratorContext context, Type returnType, Variable outputBlock) + { + return BytecodeUtils.generateWrite( + context.getCallSiteBinder(), + context.getScope(), + context.getScope().getVariable("wasNull"), + returnType, + outputBlock); + } } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeGeneratorContext.java b/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeGeneratorContext.java index 114ec93a7c86..e6450825e45c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeGeneratorContext.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeGeneratorContext.java @@ -19,6 +19,7 @@ import io.airlift.bytecode.Variable; import io.prestosql.metadata.FunctionRegistry; import io.prestosql.operator.scalar.ScalarFunctionImplementation; +import io.prestosql.sql.gen.BytecodeUtils.OutputBlockVariableAndType; import io.prestosql.sql.relational.RowExpression; import java.util.List; @@ -67,14 +68,14 @@ public CallSiteBinder getCallSiteBinder() return callSiteBinder; } - public BytecodeNode generate(RowExpression expression) + public BytecodeNode generate(RowExpression expression, Optional outputBlockVariable) { - return generate(expression, Optional.empty()); + return generate(expression, outputBlockVariable, Optional.empty()); } - public BytecodeNode generate(RowExpression expression, Optional lambdaInterface) + public BytecodeNode generate(RowExpression expression, Optional outputBlockVariable, Optional lambdaInterface) { - return rowExpressionCompiler.compile(expression, scope, lambdaInterface); + return rowExpressionCompiler.compile(expression, scope, outputBlockVariable, lambdaInterface); } public FunctionRegistry getRegistry() @@ -82,17 +83,26 @@ public FunctionRegistry getRegistry() return registry; } + public BytecodeNode generateCall(String name, ScalarFunctionImplementation function, List arguments) + { + return generateCall(name, function, arguments, Optional.empty()); + } + /** * Generates a function call with null handling, automatic binding of session parameter, etc. */ - public BytecodeNode generateCall(String name, ScalarFunctionImplementation function, List arguments) + public BytecodeNode generateCall( + String name, + ScalarFunctionImplementation function, + List arguments, + Optional outputBlockVariableAndType) { Optional instance = Optional.empty(); if (function.getInstanceFactory().isPresent()) { FieldDefinition field = cachedInstanceBinder.getCachedInstance(function.getInstanceFactory().get()); instance = Optional.of(scope.getThis().getField(field)); } - return generateInvocation(scope, name, function, instance, arguments, callSiteBinder); + return generateInvocation(scope, name, function, instance, arguments, callSiteBinder, outputBlockVariableAndType); } public Variable wasNull() diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeUtils.java b/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeUtils.java index e9ee3c79dc03..8992b2c47d94 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeUtils.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeUtils.java @@ -48,8 +48,10 @@ import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; import static io.airlift.bytecode.expression.BytecodeExpressions.invokeDynamic; import static io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentType.VALUE_TYPE; +import static io.prestosql.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention.PROVIDED_BLOCKBUILDER; import static io.prestosql.sql.gen.Bootstrap.BOOTSTRAP_METHOD; import static java.lang.String.format; +import static java.util.Objects.requireNonNull; public final class BytecodeUtils { @@ -59,25 +61,35 @@ private BytecodeUtils() public static BytecodeNode ifWasNullPopAndGoto(Scope scope, LabelNode label, Class returnType, Class... stackArgsToPop) { - return handleNullValue(scope, label, returnType, ImmutableList.copyOf(stackArgsToPop), false); + return handleNullValue(scope, label, returnType, ImmutableList.copyOf(stackArgsToPop), Optional.empty(), false); } - public static BytecodeNode ifWasNullPopAndGoto(Scope scope, LabelNode label, Class returnType, Iterable> stackArgsToPop) + public static BytecodeNode ifWasNullPopAndGoto(Scope scope, LabelNode label, Class methodReturnType, Iterable> stackArgsToPop) { - return handleNullValue(scope, label, returnType, ImmutableList.copyOf(stackArgsToPop), false); + return handleNullValue(scope, label, methodReturnType, ImmutableList.copyOf(stackArgsToPop), Optional.empty(), false); } - public static BytecodeNode ifWasNullClearPopAndGoto(Scope scope, LabelNode label, Class returnType, Class... stackArgsToPop) + public static BytecodeNode ifWasNullClearPopAndGoto(Scope scope, LabelNode label, Class methodReturnType, Class... stackArgsToPop) { - return handleNullValue(scope, label, returnType, ImmutableList.copyOf(stackArgsToPop), true); + return handleNullValue(scope, label, methodReturnType, ImmutableList.copyOf(stackArgsToPop), Optional.empty(), true); + } + + public static BytecodeNode ifWasNullClearPopAppendAndGoto(Scope scope, LabelNode label, Class methodReturnType, Variable outputBlockVariable, Iterable> stackArgsToPop) + { + return handleNullValue(scope, label, methodReturnType, ImmutableList.copyOf(stackArgsToPop), Optional.of(outputBlockVariable), true); } public static BytecodeNode handleNullValue(Scope scope, LabelNode label, - Class returnType, + Class methodReturnType, List> stackArgsToPop, + Optional outputBlockVariable, boolean clearNullFlag) { + if (outputBlockVariable.isPresent()) { + checkArgument(methodReturnType == void.class); + } + Variable wasNull = scope.getVariable("wasNull"); BytecodeBlock nullCheck = new BytecodeBlock() @@ -95,9 +107,17 @@ public static BytecodeNode handleNullValue(Scope scope, isNull.pop(parameterType); } - isNull.pushJavaDefault(returnType); - String loadDefaultComment; - loadDefaultComment = format("loadJavaDefault(%s)", returnType.getName()); + String loadDefaultOrAppendNullComment; + if (!outputBlockVariable.isPresent()) { + isNull.pushJavaDefault(methodReturnType); + loadDefaultOrAppendNullComment = format("loadJavaDefault(%s)", methodReturnType.getName()); + } + else { + isNull.append(outputBlockVariable.get() + .invoke("appendNull", BlockBuilder.class) + .pop()); + loadDefaultOrAppendNullComment = "appendNullToOutputBlock"; + } isNull.gotoLabel(label); @@ -106,7 +126,7 @@ public static BytecodeNode handleNullValue(Scope scope, popComment = format("pop(%s)", Joiner.on(", ").join(stackArgsToPop)); } - return new IfStatement("if wasNull then %s", Joiner.on(", ").skipNulls().join(clearComment, popComment, loadDefaultComment, "goto " + label.getLabel())) + return new IfStatement("if wasNull then %s", Joiner.on(", ").skipNulls().join(clearComment, popComment, loadDefaultOrAppendNullComment, "goto " + label.getLabel())) .condition(nullCheck) .ifTrue(isNull); } @@ -160,7 +180,32 @@ public static BytecodeExpression loadConstant(Binding binding) binding.getType().returnType()); } - public static BytecodeNode generateInvocation(Scope scope, String name, ScalarFunctionImplementation function, Optional instance, List arguments, CallSiteBinder binder) + public static BytecodeNode generateInvocation( + Scope scope, + String name, + ScalarFunctionImplementation function, + Optional instance, + List arguments, + CallSiteBinder binder) + { + return generateInvocation( + scope, + name, + function, + instance, + arguments, + binder, + Optional.empty()); + } + + public static BytecodeNode generateInvocation( + Scope scope, + String name, + ScalarFunctionImplementation function, + Optional instance, + List arguments, + CallSiteBinder binder, + Optional outputBlockVariableAndType) { LabelNode end = new LabelNode("end"); BytecodeBlock block = new BytecodeBlock() @@ -186,7 +231,8 @@ public static BytecodeNode generateInvocation(Scope scope, String name, ScalarFu if (currentChoice.getArgumentProperty(i).getArgumentType() != VALUE_TYPE) { continue; } - if (!(arguments.get(i) instanceof InputReferenceNode) && currentChoice.getArgumentProperty(i).getNullConvention() == NullConvention.BLOCK_AND_POSITION) { + if (currentChoice.getArgumentProperty(i).getNullConvention() == NullConvention.BLOCK_AND_POSITION && !(arguments.get(i) instanceof InputReferenceNode) + || currentChoice.getReturnPlaceConvention() == PROVIDED_BLOCKBUILDER && (!outputBlockVariableAndType.isPresent())) { isValid = false; break; } @@ -215,6 +261,9 @@ public static BytecodeNode generateInvocation(Scope scope, String name, ScalarFu else if (type == ConnectorSession.class) { block.append(scope.getVariable("session")); } + else if (type == BlockBuilder.class) { + block.append(outputBlockVariableAndType.get().getOutputBlockVariable()); + } else { ArgumentProperty argumentProperty = bestChoice.getArgumentProperty(realParameterIndex); switch (argumentProperty.getArgumentType()) { @@ -224,7 +273,17 @@ else if (type == ConnectorSession.class) { case RETURN_NULL_ON_NULL: block.append(arguments.get(realParameterIndex)); checkArgument(!Primitives.isWrapperType(type), "Non-nullable argument must not be primitive wrapper type"); - block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes))); + switch (bestChoice.getReturnPlaceConvention()) { + case STACK: + block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes))); + break; + case PROVIDED_BLOCKBUILDER: + checkArgument(unboxedReturnType == void.class); + block.append(ifWasNullClearPopAppendAndGoto(scope, end, unboxedReturnType, outputBlockVariableAndType.get().getOutputBlockVariable(), Lists.reverse(stackTypes))); + break; + default: + throw new UnsupportedOperationException(format("Unsupported return place convention: %s", bestChoice.getReturnPlaceConvention())); + } break; case USE_NULL_FLAG: block.append(arguments.get(realParameterIndex)); @@ -261,10 +320,31 @@ else if (type == ConnectorSession.class) { block.append(invoke(binding, name)); if (function.isNullable()) { - block.append(unboxPrimitiveIfNecessary(scope, returnType)); + switch (bestChoice.getReturnPlaceConvention()) { + case STACK: + block.append(unboxPrimitiveIfNecessary(scope, returnType)); + break; + case PROVIDED_BLOCKBUILDER: + // no-op + break; + default: + throw new UnsupportedOperationException(format("Unsupported return place convention: %s", bestChoice.getReturnPlaceConvention())); + } } block.visitLabel(end); + if (outputBlockVariableAndType.isPresent()) { + switch (bestChoice.getReturnPlaceConvention()) { + case STACK: + block.append(generateWrite(binder, scope, scope.getVariable("wasNull"), outputBlockVariableAndType.get().getType(), outputBlockVariableAndType.get().getOutputBlockVariable())); + break; + case PROVIDED_BLOCKBUILDER: + // no-op + break; + default: + throw new UnsupportedOperationException(format("Unsupported return place convention: %s", bestChoice.getReturnPlaceConvention())); + } + } return block; } @@ -344,7 +424,12 @@ public static BytecodeExpression invoke(Binding binding, Signature signature) return invoke(binding, signature.getName()); } - public static BytecodeNode generateWrite(CallSiteBinder callSiteBinder, Scope scope, Variable wasNullVariable, Type type) + public static BytecodeNode generateWrite( + CallSiteBinder callSiteBinder, + Scope scope, + Variable wasNullVariable, + Type type, + Variable outputBlockVariable) { Class valueJavaType = type.getJavaType(); if (!valueJavaType.isPrimitive() && valueJavaType != Slice.class) { @@ -352,15 +437,8 @@ public static BytecodeNode generateWrite(CallSiteBinder callSiteBinder, Scope sc } String methodName = "write" + Primitives.wrap(valueJavaType).getSimpleName(); - // the stack contains [output, value] - - // We should be able to insert the code to get the output variable and compute the value - // at the right place instead of assuming they are in the stack. We should also not need to - // use temp variables to re-shuffle the stack to the right shape before Type.writeXXX is called - // Unfortunately, because of the assumptions made by try_cast, we can't get around it yet. - // TODO: clean up once try_cast is fixed + // the value to be written is at the top of stack Variable tempValue = scope.createTempVariable(valueJavaType); - Variable tempOutput = scope.createTempVariable(BlockBuilder.class); return new BytecodeBlock() .comment("if (wasNull)") .append(new IfStatement() @@ -368,15 +446,37 @@ public static BytecodeNode generateWrite(CallSiteBinder callSiteBinder, Scope sc .ifTrue(new BytecodeBlock() .comment("output.appendNull();") .pop(valueJavaType) + .getVariable(outputBlockVariable) .invokeInterface(BlockBuilder.class, "appendNull", BlockBuilder.class) .pop()) .ifFalse(new BytecodeBlock() .comment("%s.%s(output, %s)", type.getTypeSignature(), methodName, valueJavaType.getSimpleName()) .putVariable(tempValue) - .putVariable(tempOutput) .append(loadConstant(callSiteBinder.bind(type, Type.class))) - .getVariable(tempOutput) + .getVariable(outputBlockVariable) .getVariable(tempValue) .invokeInterface(Type.class, methodName, void.class, BlockBuilder.class, valueJavaType))); } + + public static class OutputBlockVariableAndType + { + private final Variable outputBlockVariable; + private final Type type; + + public OutputBlockVariableAndType(Variable outputBlockVariable, Type type) + { + this.outputBlockVariable = requireNonNull(outputBlockVariable); + this.type = requireNonNull(type); + } + + public Variable getOutputBlockVariable() + { + return outputBlockVariable; + } + + public Type getType() + { + return type; + } + } } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/CastCodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/CastCodeGenerator.java index 116fdd830f78..80b89191bad6 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/CastCodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/CastCodeGenerator.java @@ -14,18 +14,23 @@ package io.prestosql.sql.gen; import com.google.common.collect.ImmutableList; +import io.airlift.bytecode.BytecodeBlock; import io.airlift.bytecode.BytecodeNode; +import io.airlift.bytecode.Variable; import io.prestosql.metadata.Signature; import io.prestosql.spi.type.Type; import io.prestosql.sql.relational.RowExpression; import java.util.List; +import java.util.Optional; + +import static io.prestosql.sql.gen.BytecodeGenerator.generateWrite; public class CastCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments, Optional outputBlockVariable) { RowExpression argument = arguments.get(0); @@ -33,6 +38,12 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon .getRegistry() .getCoercion(argument.getType(), returnType); - return generatorContext.generateCall(function.getName(), generatorContext.getRegistry().getScalarFunctionImplementation(function), ImmutableList.of(generatorContext.generate(argument))); + BytecodeBlock block = new BytecodeBlock() + .append(generatorContext.generateCall( + function.getName(), + generatorContext.getRegistry().getScalarFunctionImplementation(function), + ImmutableList.of(generatorContext.generate(argument, Optional.empty())))); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generatorContext, returnType, output))); + return block; } } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/CoalesceCodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/CoalesceCodeGenerator.java index 75f70adaaf31..14c3d31b01c5 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/CoalesceCodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/CoalesceCodeGenerator.java @@ -24,19 +24,21 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; +import static io.prestosql.sql.gen.BytecodeGenerator.generateWrite; public class CoalesceCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments, Optional outputBlockVariable) { List operands = new ArrayList<>(); for (RowExpression expression : arguments) { - operands.add(generatorContext.generate(expression)); + operands.add(generatorContext.generate(expression, Optional.empty())); } Variable wasNull = generatorContext.wasNull(); @@ -61,6 +63,9 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon nullValue = ifStatement; } - return nullValue; + BytecodeBlock block = new BytecodeBlock() + .append(nullValue); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generatorContext, returnType, output))); + return block; } } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/CursorProcessorCompiler.java b/presto-main/src/main/java/io/prestosql/sql/gen/CursorProcessorCompiler.java index 13d13b9c514a..1a63478b82ee 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/CursorProcessorCompiler.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/CursorProcessorCompiler.java @@ -46,6 +46,7 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import static io.airlift.bytecode.Access.PUBLIC; @@ -56,7 +57,6 @@ import static io.airlift.bytecode.expression.BytecodeExpressions.newInstance; import static io.airlift.bytecode.expression.BytecodeExpressions.or; import static io.airlift.bytecode.instruction.JumpInstruction.jump; -import static io.prestosql.sql.gen.BytecodeUtils.generateWrite; import static io.prestosql.sql.gen.LambdaExpressionExtractor.extractLambdaExpressions; import static java.lang.String.format; @@ -247,7 +247,7 @@ private void generateFilterMethod( .comment("boolean wasNull = false;") .putVariable(wasNullVariable, false) .comment("evaluate filter: " + filter) - .append(compiler.compile(filter, scope)) + .append(compiler.compile(filter, scope, Optional.empty())) .comment("if (wasNull) return false;") .getVariable(wasNullVariable) .ifFalseGoto(end) @@ -285,10 +285,8 @@ private void generateProjectMethod( method.getBody() .comment("boolean wasNull = false;") .putVariable(wasNullVariable, false) - .getVariable(output) .comment("evaluate projection: " + projection.toString()) - .append(compiler.compile(projection, scope)) - .append(generateWrite(callSiteBinder, scope, wasNullVariable, projection.getType())) + .append(compiler.compile(projection, scope, Optional.of(output))) .ret(); } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/DereferenceCodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/DereferenceCodeGenerator.java index 225a03217213..9906d641103e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/DereferenceCodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/DereferenceCodeGenerator.java @@ -26,16 +26,18 @@ import io.prestosql.sql.relational.RowExpression; import java.util.List; +import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; +import static io.prestosql.sql.gen.BytecodeGenerator.generateWrite; import static io.prestosql.sql.gen.SqlTypeBytecodeExpression.constantType; public class DereferenceCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generator, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generator, Type returnType, List arguments, Optional outputBlockVariable) { checkArgument(arguments.size() == 2); CallSiteBinder callSiteBinder = generator.getCallSiteBinder(); @@ -47,7 +49,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon // clear the wasNull flag before evaluating the row value block.putVariable(wasNull, false); - block.append(generator.generate(arguments.get(0))).putVariable(rowBlock); + block.append(generator.generate(arguments.get(0), Optional.empty())).putVariable(rowBlock); IfStatement ifRowBlockIsNull = new IfStatement("if row block is null...") .condition(wasNull); @@ -84,6 +86,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon block.append(ifFieldIsNull) .visitLabel(end); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generator, returnType, output))); return block; } } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/FunctionCallCodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/FunctionCallCodeGenerator.java index 9125619356ba..f11b29318a46 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/FunctionCallCodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/FunctionCallCodeGenerator.java @@ -14,10 +14,12 @@ package io.prestosql.sql.gen; import io.airlift.bytecode.BytecodeNode; +import io.airlift.bytecode.Variable; import io.prestosql.metadata.FunctionRegistry; import io.prestosql.metadata.Signature; import io.prestosql.operator.scalar.ScalarFunctionImplementation; import io.prestosql.spi.type.Type; +import io.prestosql.sql.gen.BytecodeUtils.OutputBlockVariableAndType; import io.prestosql.sql.relational.RowExpression; import java.util.ArrayList; @@ -30,7 +32,7 @@ public class FunctionCallCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments, Optional outputBlockVariable) { FunctionRegistry registry = context.getRegistry(); @@ -41,13 +43,17 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon RowExpression argument = arguments.get(i); ScalarFunctionImplementation.ArgumentProperty argumentProperty = function.getArgumentProperty(i); if (argumentProperty.getArgumentType() == VALUE_TYPE) { - argumentsBytecode.add(context.generate(argument)); + argumentsBytecode.add(context.generate(argument, Optional.empty())); } else { - argumentsBytecode.add(context.generate(argument, Optional.of(argumentProperty.getLambdaInterface()))); + argumentsBytecode.add(context.generate(argument, Optional.empty(), Optional.of(argumentProperty.getLambdaInterface()))); } } - return context.generateCall(signature.getName(), function, argumentsBytecode); + return context.generateCall( + signature.getName(), + function, + argumentsBytecode, + outputBlockVariable.map(variable -> new OutputBlockVariableAndType(variable, returnType))); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/IfCodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/IfCodeGenerator.java index f79714093e66..c03bacbfe6ee 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/IfCodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/IfCodeGenerator.java @@ -23,29 +23,33 @@ import io.prestosql.sql.relational.RowExpression; import java.util.List; +import java.util.Optional; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; +import static io.prestosql.sql.gen.BytecodeGenerator.generateWrite; public class IfCodeGenerator implements BytecodeGenerator { - @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments, Optional outputBlockVariable) { Preconditions.checkArgument(arguments.size() == 3); Variable wasNull = context.wasNull(); BytecodeBlock condition = new BytecodeBlock() - .append(context.generate(arguments.get(0))) + .append(context.generate(arguments.get(0), Optional.empty())) .comment("... and condition value was not null") .append(wasNull) .invokeStatic(CompilerOperations.class, "not", boolean.class, boolean.class) .invokeStatic(CompilerOperations.class, "and", boolean.class, boolean.class, boolean.class) .append(wasNull.set(constantFalse())); - return new IfStatement() - .condition(condition) - .ifTrue(context.generate(arguments.get(1))) - .ifFalse(context.generate(arguments.get(2))); + BytecodeBlock block = new BytecodeBlock() + .append(new IfStatement() + .condition(condition) + .ifTrue(context.generate(arguments.get(1), Optional.empty())) + .ifFalse(context.generate(arguments.get(2), Optional.empty()))); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(context, returnType, output))); + return block; } } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/InCodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/InCodeGenerator.java index d5b575719180..afb84aab933e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/InCodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/InCodeGenerator.java @@ -40,6 +40,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; @@ -50,6 +51,7 @@ import static io.airlift.bytecode.instruction.JumpInstruction.jump; import static io.prestosql.spi.function.OperatorType.HASH_CODE; import static io.prestosql.spi.function.OperatorType.INDETERMINATE; +import static io.prestosql.sql.gen.BytecodeGenerator.generateWrite; import static io.prestosql.sql.gen.BytecodeUtils.ifWasNullPopAndGoto; import static io.prestosql.sql.gen.BytecodeUtils.invoke; import static io.prestosql.sql.gen.BytecodeUtils.loadConstant; @@ -107,7 +109,7 @@ static SwitchGenerationCase checkSwitchGenerationCase(Type type, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments, Optional outputBlockVariable) { List values = arguments.subList(1, arguments.size()); // empty IN statements are not allowed by the standard, and not possible here @@ -129,7 +131,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon ImmutableSet.Builder constantValuesBuilder = ImmutableSet.builder(); for (RowExpression testValue : values) { - BytecodeNode testBytecode = generatorContext.generate(testValue); + BytecodeNode testBytecode = generatorContext.generate(testValue, Optional.empty()); if (isDeterminateConstant(testValue, isIndeterminateFunction.getMethodHandle())) { ConstantExpression constant = (ConstantExpression) testValue; @@ -252,7 +254,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon BytecodeBlock block = new BytecodeBlock() .comment("IN") - .append(generatorContext.generate(arguments.get(0))) + .append(generatorContext.generate(arguments.get(0), Optional.empty())) .append(ifWasNullPopAndGoto(scope, end, boolean.class, javaType)) .putVariable(value) .append(switchBlock) @@ -276,6 +278,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon block.visitLabel(end); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generatorContext, returnType, output))); return block; } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/InvokeFunctionBytecodeExpression.java b/presto-main/src/main/java/io/prestosql/sql/gen/InvokeFunctionBytecodeExpression.java index 9394ab9e2591..f70539ba5ba5 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/InvokeFunctionBytecodeExpression.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/InvokeFunctionBytecodeExpression.java @@ -65,7 +65,13 @@ private InvokeFunctionBytecodeExpression( { super(type(Primitives.unwrap(function.getMethodHandle().type().returnType()))); - this.invocation = generateInvocation(scope, name, function, instance, parameters.stream().map(BytecodeNode.class::cast).collect(toImmutableList()), binder); + this.invocation = generateInvocation( + scope, + name, + function, + instance, + parameters.stream().map(BytecodeNode.class::cast).collect(toImmutableList()), + binder); this.oneLineDescription = name + "(" + Joiner.on(", ").join(parameters) + ")"; } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/IsNullCodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/IsNullCodeGenerator.java index e4ab54644b2e..6687f30ea9d7 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/IsNullCodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/IsNullCodeGenerator.java @@ -22,16 +22,18 @@ import io.prestosql.sql.relational.RowExpression; import java.util.List; +import java.util.Optional; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.instruction.Constant.loadBoolean; +import static io.prestosql.sql.gen.BytecodeGenerator.generateWrite; import static io.prestosql.type.UnknownType.UNKNOWN; public class IsNullCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments, Optional outputBlockVariable) { Preconditions.checkArgument(arguments.size() == 1); @@ -40,7 +42,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon return loadBoolean(true); } - BytecodeNode value = generatorContext.generate(argument); + BytecodeNode value = generatorContext.generate(argument, Optional.empty()); // evaluate the expression, pop the produced value, and load the null flag Variable wasNull = generatorContext.wasNull(); @@ -53,6 +55,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon // clear the null flag block.append(wasNull.set(constantFalse())); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generatorContext, returnType, output))); return block; } } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/JoinFilterFunctionCompiler.java b/presto-main/src/main/java/io/prestosql/sql/gen/JoinFilterFunctionCompiler.java index 14ca60e03a31..5a305b1c2759 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/JoinFilterFunctionCompiler.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/JoinFilterFunctionCompiler.java @@ -50,6 +50,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import static com.google.common.base.MoreObjects.toStringHelper; @@ -198,7 +199,7 @@ private void generateFilterMethod( metadata.getFunctionRegistry(), compiledLambdaMap); - BytecodeNode visitorBody = compiler.compile(filter, scope); + BytecodeNode visitorBody = compiler.compile(filter, scope, Optional.empty()); Variable result = scope.declareVariable(boolean.class, "result"); body.append(visitorBody) diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/LambdaBytecodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/LambdaBytecodeGenerator.java index 7cbc70c39c0d..64c70fb36233 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/LambdaBytecodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/LambdaBytecodeGenerator.java @@ -48,6 +48,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -156,7 +157,7 @@ private static CompiledLambda defineLambdaMethod( Scope scope = method.getScope(); Variable wasNull = scope.declareVariable(boolean.class, "wasNull"); - BytecodeNode compiledBody = innerExpressionCompiler.compile(lambda.getBody(), scope); + BytecodeNode compiledBody = innerExpressionCompiler.compile(lambda.getBody(), scope, Optional.empty()); method.getBody() .putVariable(wasNull, false) .append(compiledBody) @@ -197,7 +198,7 @@ public static BytecodeNode generateLambda( for (RowExpression captureExpression : captureExpressions) { Class valueType = Primitives.wrap(captureExpression.getType().getJavaType()); Variable valueVariable = scope.createTempVariable(valueType); - block.append(context.generate(captureExpression)); + block.append(context.generate(captureExpression, Optional.empty())); block.append(boxPrimitiveIfNecessary(scope, valueType)); block.putVariable(valueVariable); block.append(wasNull.set(constantFalse())); diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/NullIfCodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/NullIfCodeGenerator.java index d953e2372a3e..f10c8aaa2b8b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/NullIfCodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/NullIfCodeGenerator.java @@ -28,15 +28,17 @@ import io.prestosql.sql.relational.RowExpression; import java.util.List; +import java.util.Optional; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; +import static io.prestosql.sql.gen.BytecodeGenerator.generateWrite; import static io.prestosql.sql.gen.BytecodeUtils.ifWasNullPopAndGoto; public class NullIfCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments, Optional outputBlockVariable) { Scope scope = generatorContext.getScope(); @@ -49,7 +51,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon Variable firstValue = scope.createTempVariable(first.getType().getJavaType()); BytecodeBlock block = new BytecodeBlock() .comment("check if first arg is null") - .append(generatorContext.generate(first)) + .append(generatorContext.generate(first, Optional.empty())) .append(ifWasNullPopAndGoto(scope, notMatch, void.class)) .dup(first.getType().getJavaType()) .putVariable(firstValue); @@ -65,7 +67,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon equalsFunction, ImmutableList.of( cast(generatorContext, firstValue, firstType, equalsSignature.getArgumentTypes().get(0)), - cast(generatorContext, generatorContext.generate(second), secondType, equalsSignature.getArgumentTypes().get(1)))); + cast(generatorContext, generatorContext.generate(second, Optional.empty()), secondType, equalsSignature.getArgumentTypes().get(1)))); BytecodeBlock conditionBlock = new BytecodeBlock() .append(equalsCall) @@ -83,6 +85,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon .ifTrue(trueBlock) .ifFalse(notMatch)); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generatorContext, returnType, output))); return block; } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/OrCodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/OrCodeGenerator.java index b2ed3ea097bc..e2085e44c561 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/OrCodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/OrCodeGenerator.java @@ -24,14 +24,16 @@ import io.prestosql.sql.relational.RowExpression; import java.util.List; +import java.util.Optional; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; +import static io.prestosql.sql.gen.BytecodeGenerator.generateWrite; public class OrCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generator, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generator, Type returnType, List arguments, Optional outputBlockVariable) { Preconditions.checkArgument(arguments.size() == 2); @@ -40,8 +42,8 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon .comment("OR") .setDescription("OR"); - BytecodeNode left = generator.generate(arguments.get(0)); - BytecodeNode right = generator.generate(arguments.get(1)); + BytecodeNode left = generator.generate(arguments.get(0), Optional.empty()); + BytecodeNode right = generator.generate(arguments.get(1), Optional.empty()); block.append(left); @@ -96,6 +98,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon block.append(ifRightIsNull) .visitLabel(end); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generator, returnType, output))); return block; } } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/PageFunctionCompiler.java b/presto-main/src/main/java/io/prestosql/sql/gen/PageFunctionCompiler.java index c5c50a676352..24a4354d487c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/PageFunctionCompiler.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/PageFunctionCompiler.java @@ -87,7 +87,6 @@ import static io.airlift.bytecode.expression.BytecodeExpressions.not; import static io.prestosql.operator.project.PageFieldsToInputParametersRewriter.rewritePageFieldsToInputParameters; import static io.prestosql.spi.StandardErrorCode.COMPILER_ERROR; -import static io.prestosql.sql.gen.BytecodeUtils.generateWrite; import static io.prestosql.sql.gen.BytecodeUtils.invoke; import static io.prestosql.sql.gen.LambdaExpressionExtractor.extractLambdaExpressions; import static io.prestosql.util.CompilerUtils.defineClass; @@ -342,7 +341,7 @@ private MethodDefinition generateEvaluateMethod( declareBlockVariables(projection, page, scope, body); - Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse()); + scope.declareVariable("wasNull", body, constantFalse()); RowExpressionCompiler compiler = new RowExpressionCompiler( callSiteBinder, cachedInstanceBinder, @@ -350,9 +349,9 @@ private MethodDefinition generateEvaluateMethod( metadata.getFunctionRegistry(), compiledLambdaMap); - body.append(thisVariable.getField(blockBuilder)) - .append(compiler.compile(projection, scope)) - .append(generateWrite(callSiteBinder, scope, wasNullVariable, projection.getType())) + Variable outputBlockVariable = scope.createTempVariable(BlockBuilder.class); + body.append(outputBlockVariable.set(thisVariable.getField(blockBuilder))) + .append(compiler.compile(projection, scope, Optional.of(outputBlockVariable))) .ret(); return method; } @@ -525,7 +524,7 @@ private MethodDefinition generateFilterMethod( compiledLambdaMap); Variable result = scope.declareVariable(boolean.class, "result"); - body.append(compiler.compile(filter, scope)) + body.append(compiler.compile(filter, scope, Optional.empty())) // store result so we can check for null .putVariable(result) .append(and(not(wasNullVariable), result).ret()); diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/RowConstructorCodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/RowConstructorCodeGenerator.java index 03570cf088be..75e73751d2ea 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/RowConstructorCodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/RowConstructorCodeGenerator.java @@ -26,17 +26,19 @@ import io.prestosql.sql.relational.RowExpression; import java.util.List; +import java.util.Optional; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; import static io.airlift.bytecode.expression.BytecodeExpressions.constantNull; +import static io.prestosql.sql.gen.BytecodeGenerator.generateWrite; import static io.prestosql.sql.gen.SqlTypeBytecodeExpression.constantType; public class RowConstructorCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type rowType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type rowType, List arguments, Optional outputBlockVariable) { BytecodeBlock block = new BytecodeBlock().setDescription("Constructor for " + rowType.toString()); CallSiteBinder binder = context.getCallSiteBinder(); @@ -59,7 +61,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon Variable field = scope.createTempVariable(fieldType.getJavaType()); block.comment("Clean wasNull and Generate + " + i + "-th field of row"); block.append(context.wasNull().set(constantFalse())); - block.append(context.generate(arguments.get(i))); + block.append(context.generate(arguments.get(i), Optional.empty())); block.putVariable(field); block.append(new IfStatement() .condition(context.wasNull()) @@ -71,6 +73,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon block.append(constantType(binder, rowType).invoke("getObject", Object.class, blockBuilder.cast(Block.class), constantInt(0)) .cast(Block.class)); block.append(context.wasNull().set(constantFalse())); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(context, rowType, output))); return block; } } diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/RowExpressionCompiler.java b/presto-main/src/main/java/io/prestosql/sql/gen/RowExpressionCompiler.java index b4f2dd38559a..dd4baa7b6f58 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/RowExpressionCompiler.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/RowExpressionCompiler.java @@ -18,6 +18,7 @@ import io.airlift.bytecode.BytecodeBlock; import io.airlift.bytecode.BytecodeNode; import io.airlift.bytecode.Scope; +import io.airlift.bytecode.Variable; import io.prestosql.metadata.FunctionRegistry; import io.prestosql.sql.gen.LambdaBytecodeGenerator.CompiledLambda; import io.prestosql.sql.relational.CallExpression; @@ -31,6 +32,7 @@ import java.util.Map; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; import static io.airlift.bytecode.instruction.Constant.loadBoolean; @@ -39,6 +41,7 @@ import static io.airlift.bytecode.instruction.Constant.loadInt; import static io.airlift.bytecode.instruction.Constant.loadLong; import static io.airlift.bytecode.instruction.Constant.loadString; +import static io.prestosql.sql.gen.BytecodeUtils.generateWrite; import static io.prestosql.sql.gen.BytecodeUtils.loadConstant; import static io.prestosql.sql.gen.LambdaBytecodeGenerator.generateLambda; import static io.prestosql.sql.relational.Signatures.BIND; @@ -74,14 +77,16 @@ public class RowExpressionCompiler this.compiledLambdaMap = compiledLambdaMap; } - public BytecodeNode compile(RowExpression rowExpression, Scope scope) + public BytecodeNode compile(RowExpression rowExpression, Scope scope, Optional outputBlockVariable) { - return compile(rowExpression, scope, Optional.empty()); + return compile(rowExpression, scope, outputBlockVariable, Optional.empty()); } - public BytecodeNode compile(RowExpression rowExpression, Scope scope, Optional lambdaInterface) + // When outputBlockVariable is presented, the generated bytecode will write the evaluated value into the outputBlockVariable, + // otherwise the value will be left on stack. + public BytecodeNode compile(RowExpression rowExpression, Scope scope, Optional outputBlockVariable, Optional lambdaInterface) { - return rowExpression.accept(new Visitor(), new Context(scope, lambdaInterface)); + return rowExpression.accept(new Visitor(), new Context(scope, outputBlockVariable, lambdaInterface)); } private class Visitor @@ -147,7 +152,7 @@ public BytecodeNode visitCall(CallExpression call, Context context) cachedInstanceBinder, registry); - return generator.generateExpression(call.getSignature(), generatorContext, call.getType(), call.getArguments()); + return generator.generateExpression(call.getSignature(), generatorContext, call.getType(), call.getArguments(), context.getOutputBlockVariable()); } @Override @@ -158,50 +163,76 @@ public BytecodeNode visitConstant(ConstantExpression constant, Context context) BytecodeBlock block = new BytecodeBlock(); if (value == null) { - return block.comment("constant null") + block.comment("constant null") .append(context.getScope().getVariable("wasNull").set(constantTrue())) .pushJavaDefault(javaType); } + else { + // use LDC for primitives (boolean, short, int, long, float, double) + block.comment("constant " + constant.getType().getTypeSignature()); + if (javaType == boolean.class) { + block.append(loadBoolean((Boolean) value)); + } + else if (javaType == byte.class || javaType == short.class || javaType == int.class) { + block.append(loadInt(((Number) value).intValue())); + } + else if (javaType == long.class) { + block.append(loadLong((Long) value)); + } + else if (javaType == float.class) { + block.append(loadFloat((Float) value)); + } + else if (javaType == double.class) { + block.append(loadDouble((Double) value)); + } + else if (javaType == String.class) { + block.append(loadString((String) value)); + } + else { + // bind constant object directly into the call-site using invoke dynamic + Binding binding = callSiteBinder.bind(value, constant.getType().getJavaType()); - // use LDC for primitives (boolean, short, int, long, float, double) - block.comment("constant " + constant.getType().getTypeSignature()); - if (javaType == boolean.class) { - return block.append(loadBoolean((Boolean) value)); - } - if (javaType == byte.class || javaType == short.class || javaType == int.class) { - return block.append(loadInt(((Number) value).intValue())); - } - if (javaType == long.class) { - return block.append(loadLong((Long) value)); - } - if (javaType == float.class) { - return block.append(loadFloat((Float) value)); - } - if (javaType == double.class) { - return block.append(loadDouble((Double) value)); - } - if (javaType == String.class) { - return block.append(loadString((String) value)); + block = new BytecodeBlock() + .setDescription("constant " + constant.getType()) + .comment(constant.toString()) + .append(loadConstant(binding)); + } } - // bind constant object directly into the call-site using invoke dynamic - Binding binding = callSiteBinder.bind(value, constant.getType().getJavaType()); + if (context.getOutputBlockVariable().isPresent()) { + block.append(generateWrite( + callSiteBinder, + context.getScope(), + context.getScope().getVariable("wasNull"), + constant.getType(), + context.getOutputBlockVariable().get())); + } - return new BytecodeBlock() - .setDescription("constant " + constant.getType()) - .comment(constant.toString()) - .append(loadConstant(binding)); + return block; } @Override public BytecodeNode visitInputReference(InputReferenceExpression node, Context context) { - return fieldReferenceCompiler.visitInputReference(node, context.getScope()); + BytecodeNode inputReferenceBytecode = fieldReferenceCompiler.visitInputReference(node, context.getScope()); + if (!context.getOutputBlockVariable().isPresent()) { + return inputReferenceBytecode; + } + + return new BytecodeBlock() + .append(inputReferenceBytecode) + .append(generateWrite( + callSiteBinder, + context.getScope(), + context.getScope().getVariable("wasNull"), + node.getType(), + context.getOutputBlockVariable().get())); } @Override public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Context context) { + checkArgument(!context.getOutputBlockVariable().isPresent(), "lambda definition expression does not support writing to block"); checkState(compiledLambdaMap.containsKey(lambda), "lambda expressions map does not contain this lambda definition"); if (!context.lambdaInterface.get().isAnnotationPresent(FunctionalInterface.class)) { // lambdaInterface is checked to be annotated with FunctionalInterface when generating ScalarFunctionImplementation @@ -225,18 +256,32 @@ public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Context conte @Override public BytecodeNode visitVariableReference(VariableReferenceExpression reference, Context context) { - return fieldReferenceCompiler.visitVariableReference(reference, context.getScope()); + BytecodeNode variableReferenceByteCode = fieldReferenceCompiler.visitVariableReference(reference, context.getScope()); + if (!context.getOutputBlockVariable().isPresent()) { + return variableReferenceByteCode; + } + + return new BytecodeBlock() + .append(variableReferenceByteCode) + .append(generateWrite( + callSiteBinder, + context.getScope(), + context.getScope().getVariable("wasNull"), + reference.getType(), + context.getOutputBlockVariable().get())); } } private static class Context { private final Scope scope; + private final Optional outputBlockVariable; private final Optional lambdaInterface; - public Context(Scope scope, Optional lambdaInterface) + public Context(Scope scope, Optional outputBlockVariable, Optional lambdaInterface) { this.scope = scope; + this.outputBlockVariable = outputBlockVariable; this.lambdaInterface = lambdaInterface; } @@ -245,6 +290,11 @@ public Scope getScope() return scope; } + public Optional getOutputBlockVariable() + { + return outputBlockVariable; + } + public Optional getLambdaInterface() { return lambdaInterface; diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/SwitchCodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/SwitchCodeGenerator.java index 1de5e002be4d..5fb73372417c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/SwitchCodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/SwitchCodeGenerator.java @@ -30,15 +30,17 @@ import io.prestosql.sql.relational.RowExpression; import java.util.List; +import java.util.Optional; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; +import static io.prestosql.sql.gen.BytecodeGenerator.generateWrite; public class SwitchCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments, Optional outputBlockVariable) { // TODO: compile as /* @@ -74,7 +76,7 @@ else if ( == ) { // process value, else, and all when clauses RowExpression value = arguments.get(0); - BytecodeNode valueBytecode = generatorContext.generate(value); + BytecodeNode valueBytecode = generatorContext.generate(value, Optional.empty()); BytecodeNode elseValue; List whenClauses; @@ -87,7 +89,7 @@ else if ( == ) { } else { whenClauses = arguments.subList(1, arguments.size() - 1); - elseValue = generatorContext.generate(last); + elseValue = generatorContext.generate(last, Optional.empty()); } // determine the type of the value and result @@ -123,7 +125,7 @@ else if ( == ) { BytecodeNode equalsCall = generatorContext.generateCall( equalsFunction.getName(), generatorContext.getRegistry().getScalarFunctionImplementation(equalsFunction), - ImmutableList.of(generatorContext.generate(operand), getTempVariableNode)); + ImmutableList.of(generatorContext.generate(operand, Optional.empty()), getTempVariableNode)); BytecodeBlock condition = new BytecodeBlock() .append(equalsCall) @@ -131,10 +133,12 @@ else if ( == ) { elseValue = new IfStatement("when") .condition(condition) - .ifTrue(generatorContext.generate(result)) + .ifTrue(generatorContext.generate(result, Optional.empty())) .ifFalse(elseValue); } - return block.append(elseValue); + block.append(elseValue); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generatorContext, returnType, output))); + return block; } } diff --git a/presto-main/src/test/java/io/prestosql/operator/scalar/TestProvidedBlockBuilderReturnPlaceConvention.java b/presto-main/src/test/java/io/prestosql/operator/scalar/TestProvidedBlockBuilderReturnPlaceConvention.java new file mode 100644 index 000000000000..25349b600c69 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/operator/scalar/TestProvidedBlockBuilderReturnPlaceConvention.java @@ -0,0 +1,401 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.operator.scalar; + +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import io.prestosql.metadata.BoundVariables; +import io.prestosql.metadata.FunctionKind; +import io.prestosql.metadata.FunctionRegistry; +import io.prestosql.metadata.Signature; +import io.prestosql.metadata.SqlScalarFunction; +import io.prestosql.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention; +import io.prestosql.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.type.ArrayType; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.type.TypeManager; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; + +import static com.google.common.primitives.Primitives.wrap; +import static io.prestosql.metadata.Signature.typeVariable; +import static io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty; +import static io.prestosql.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL; +import static io.prestosql.operator.scalar.ScalarFunctionImplementation.NullConvention.USE_BOXED_TYPE; +import static io.prestosql.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention.PROVIDED_BLOCKBUILDER; +import static io.prestosql.operator.scalar.TestProvidedBlockBuilderReturnPlaceConvention.FunctionWithProvidedBlockReturnPlaceConvention1.PROVIDED_BLOCKBUILDER_CONVENTION1; +import static io.prestosql.operator.scalar.TestProvidedBlockBuilderReturnPlaceConvention.FunctionWithProvidedBlockReturnPlaceConvention2.PROVIDED_BLOCKBUILDER_CONVENTION2; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.TypeSignature.parseTypeSignature; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static io.prestosql.spi.type.VarcharType.createVarcharType; +import static io.prestosql.util.Reflection.methodHandle; +import static org.testng.Assert.assertTrue; + +public class TestProvidedBlockBuilderReturnPlaceConvention + extends AbstractTestFunctions +{ + @BeforeClass + public void setUp() + { + registerScalarFunction(PROVIDED_BLOCKBUILDER_CONVENTION1); + registerScalarFunction(PROVIDED_BLOCKBUILDER_CONVENTION2); + } + + @Test + public void testProvidedBlockBuilderReturnNullOnNull() + { + assertFunction("identity1(123)", INTEGER, 123); + assertFunction("identity1(identity1(123))", INTEGER, 123); + assertFunction("identity1(CAST(null AS INTEGER))", INTEGER, null); + assertFunction("identity1(identity1(CAST(null AS INTEGER)))", INTEGER, null); + + assertFunction("identity1(123.4E0)", DOUBLE, 123.4); + assertFunction("identity1(identity1(123.4E0))", DOUBLE, 123.4); + assertFunction("identity1(CAST(null AS DOUBLE))", DOUBLE, null); + assertFunction("identity1(identity1(CAST(null AS DOUBLE)))", DOUBLE, null); + + assertFunction("identity1(true)", BOOLEAN, true); + assertFunction("identity1(identity1(true))", BOOLEAN, true); + assertFunction("identity1(CAST(null AS BOOLEAN))", BOOLEAN, null); + assertFunction("identity1(identity1(CAST(null AS BOOLEAN)))", BOOLEAN, null); + + assertFunction("identity1('abc')", createVarcharType(3), "abc"); + assertFunction("identity1(identity1('abc'))", createVarcharType(3), "abc"); + assertFunction("identity1(CAST(null AS VARCHAR))", VARCHAR, null); + assertFunction("identity1(identity1(CAST(null AS VARCHAR)))", VARCHAR, null); + + assertFunction("identity1(ARRAY[1,2,3])", new ArrayType(INTEGER), ImmutableList.of(1, 2, 3)); + assertFunction("identity1(identity1(ARRAY[1,2,3]))", new ArrayType(INTEGER), ImmutableList.of(1, 2, 3)); + assertFunction("identity1(CAST(null AS ARRAY))", new ArrayType(INTEGER), null); + assertFunction("identity1(identity1(CAST(null AS ARRAY)))", new ArrayType(INTEGER), null); + + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention1.hitProvidedBlockBuilderLong.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention1.hitProvidedBlockBuilderDouble.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention1.hitProvidedBlockBuilderBoolean.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention1.hitProvidedBlockBuilderSlice.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention1.hitProvidedBlockBuilderBlock.get() > 0); + } + + @Test + public void testProvidedBlockBuilderUseBoxedType() + { + assertFunction("identity2(123)", INTEGER, 123); + assertFunction("identity2(identity2(123))", INTEGER, 123); + assertFunction("identity2(CAST(null AS INTEGER))", INTEGER, null); + assertFunction("identity2(identity2(CAST(null AS INTEGER)))", INTEGER, null); + + assertFunction("identity2(123.4E0)", DOUBLE, 123.4); + assertFunction("identity2(identity2(123.4E0))", DOUBLE, 123.4); + assertFunction("identity2(CAST(null AS DOUBLE))", DOUBLE, null); + assertFunction("identity2(identity2(CAST(null AS DOUBLE)))", DOUBLE, null); + + assertFunction("identity2(true)", BOOLEAN, true); + assertFunction("identity2(identity2(true))", BOOLEAN, true); + assertFunction("identity2(CAST(null AS BOOLEAN))", BOOLEAN, null); + assertFunction("identity2(identity2(CAST(null AS BOOLEAN)))", BOOLEAN, null); + + assertFunction("identity2('abc')", createVarcharType(3), "abc"); + assertFunction("identity2(identity2('abc'))", createVarcharType(3), "abc"); + assertFunction("identity2(CAST(null AS VARCHAR))", VARCHAR, null); + assertFunction("identity2(identity2(CAST(null AS VARCHAR)))", VARCHAR, null); + + assertFunction("identity2(ARRAY[1,2,3])", new ArrayType(INTEGER), ImmutableList.of(1, 2, 3)); + assertFunction("identity2(identity2(ARRAY[1,2,3]))", new ArrayType(INTEGER), ImmutableList.of(1, 2, 3)); + assertFunction("identity2(CAST(null AS ARRAY))", new ArrayType(INTEGER), null); + assertFunction("identity2(identity2(CAST(null AS ARRAY)))", new ArrayType(INTEGER), null); + + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention2.hitProvidedBlockBuilderLong.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention2.hitProvidedBlockBuilderDouble.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention2.hitProvidedBlockBuilderBoolean.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention2.hitProvidedBlockBuilderSlice.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention2.hitProvidedBlockBuilderBlock.get() > 0); + } + + // null convention RETURN_NULL_ON_NULL + public static class FunctionWithProvidedBlockReturnPlaceConvention1 + extends SqlScalarFunction + { + private static final AtomicLong hitProvidedBlockBuilderLong = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderDouble = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderBoolean = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderSlice = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderBlock = new AtomicLong(); + + public static final FunctionWithProvidedBlockReturnPlaceConvention1 PROVIDED_BLOCKBUILDER_CONVENTION1 = new FunctionWithProvidedBlockReturnPlaceConvention1(); + + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_LONG = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention1.class, "providedBlockLong", Type.class, BlockBuilder.class, long.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_DOUBLE = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention1.class, "providedBlockDouble", Type.class, BlockBuilder.class, double.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_BOOLEAN = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention1.class, "providedBlockBoolean", Type.class, BlockBuilder.class, boolean.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_SLICE = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention1.class, "providedBlockSlice", Type.class, BlockBuilder.class, Slice.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_BLOCK = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention1.class, "providedBlockBlock", Type.class, BlockBuilder.class, Block.class); + + protected FunctionWithProvidedBlockReturnPlaceConvention1() + { + super(new Signature( + "identity1", + FunctionKind.SCALAR, + ImmutableList.of(typeVariable("T")), + ImmutableList.of(), + parseTypeSignature("T"), + ImmutableList.of(parseTypeSignature("T")), + false)); + } + + @Override + public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) + { + Type type = boundVariables.getTypeVariable("T"); + MethodHandle methodHandleStack = MethodHandles.identity(type.getJavaType()); + MethodHandle methodHandleProvidedBlock; + if (type.getJavaType() == long.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_LONG.bindTo(type); + } + else if (type.getJavaType() == double.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_DOUBLE.bindTo(type); + } + else if (type.getJavaType() == boolean.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_BOOLEAN.bindTo(type); + } + else if (type.getJavaType() == Slice.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_SLICE.bindTo(type); + } + else if (type.getJavaType() == Block.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_BLOCK.bindTo(type); + } + else { + throw new UnsupportedOperationException(); + } + + return new ScalarFunctionImplementation( + ImmutableList.of( + new ScalarImplementationChoice( + false, + ImmutableList.of(valueTypeArgumentProperty(RETURN_NULL_ON_NULL)), + ReturnPlaceConvention.STACK, + methodHandleStack, + Optional.empty()), + new ScalarImplementationChoice( + false, + ImmutableList.of(valueTypeArgumentProperty(RETURN_NULL_ON_NULL)), + PROVIDED_BLOCKBUILDER, + methodHandleProvidedBlock, + Optional.empty())), + isDeterministic()); + } + + public static void providedBlockLong(Type type, BlockBuilder output, long value) + { + hitProvidedBlockBuilderLong.incrementAndGet(); + type.writeLong(output, value); + } + + public static void providedBlockDouble(Type type, BlockBuilder output, double value) + { + hitProvidedBlockBuilderDouble.incrementAndGet(); + type.writeDouble(output, value); + } + + public static void providedBlockBoolean(Type type, BlockBuilder output, boolean value) + { + hitProvidedBlockBuilderBoolean.incrementAndGet(); + type.writeBoolean(output, value); + } + + public static void providedBlockSlice(Type type, BlockBuilder output, Slice value) + { + hitProvidedBlockBuilderSlice.incrementAndGet(); + type.writeSlice(output, value); + } + + public static void providedBlockBlock(Type type, BlockBuilder output, Block value) + { + hitProvidedBlockBuilderBlock.incrementAndGet(); + type.writeObject(output, value); + } + + @Override + public boolean isDeterministic() + { + return true; + } + + @Override + public boolean isHidden() + { + return false; + } + + @Override + public String getDescription() + { + return ""; + } + } + + // null convention USE_BOXED_TYPE + public static class FunctionWithProvidedBlockReturnPlaceConvention2 + extends SqlScalarFunction + { + private static final AtomicLong hitProvidedBlockBuilderLong = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderDouble = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderBoolean = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderSlice = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderBlock = new AtomicLong(); + + public static final FunctionWithProvidedBlockReturnPlaceConvention2 PROVIDED_BLOCKBUILDER_CONVENTION2 = new FunctionWithProvidedBlockReturnPlaceConvention2(); + + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_LONG = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention2.class, "providedBlockLong", Type.class, BlockBuilder.class, Long.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_DOUBLE = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention2.class, "providedBlockDouble", Type.class, BlockBuilder.class, Double.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_BOOLEAN = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention2.class, "providedBlockBoolean", Type.class, BlockBuilder.class, Boolean.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_SLICE = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention2.class, "providedBlockSlice", Type.class, BlockBuilder.class, Slice.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_BLOCK = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention2.class, "providedBlockBlock", Type.class, BlockBuilder.class, Block.class); + + protected FunctionWithProvidedBlockReturnPlaceConvention2() + { + super(new Signature( + "identity2", + FunctionKind.SCALAR, + ImmutableList.of(typeVariable("T")), + ImmutableList.of(), + parseTypeSignature("T"), + ImmutableList.of(parseTypeSignature("T")), + false)); + } + + @Override + public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) + { + Type type = boundVariables.getTypeVariable("T"); + MethodHandle methodHandleStack = MethodHandles.identity(wrap(type.getJavaType())); + MethodHandle methodHandleProvidedBlock; + if (type.getJavaType() == long.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_LONG.bindTo(type); + } + else if (type.getJavaType() == double.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_DOUBLE.bindTo(type); + } + else if (type.getJavaType() == boolean.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_BOOLEAN.bindTo(type); + } + else if (type.getJavaType() == Slice.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_SLICE.bindTo(type); + } + else if (type.getJavaType() == Block.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_BLOCK.bindTo(type); + } + else { + throw new UnsupportedOperationException(); + } + + return new ScalarFunctionImplementation( + ImmutableList.of( + new ScalarImplementationChoice( + true, + ImmutableList.of(valueTypeArgumentProperty(USE_BOXED_TYPE)), + ReturnPlaceConvention.STACK, + methodHandleStack, + Optional.empty()), + new ScalarImplementationChoice( + true, + ImmutableList.of(valueTypeArgumentProperty(USE_BOXED_TYPE)), + PROVIDED_BLOCKBUILDER, + methodHandleProvidedBlock, + Optional.empty())), + isDeterministic()); + } + + public static void providedBlockLong(Type type, BlockBuilder output, Long value) + { + hitProvidedBlockBuilderLong.incrementAndGet(); + if (value == null) { + output.appendNull(); + } + else { + type.writeLong(output, value); + } + } + + public static void providedBlockDouble(Type type, BlockBuilder output, Double value) + { + hitProvidedBlockBuilderDouble.incrementAndGet(); + if (value == null) { + output.appendNull(); + } + else { + type.writeDouble(output, value); + } + } + + public static void providedBlockBoolean(Type type, BlockBuilder output, Boolean value) + { + hitProvidedBlockBuilderBoolean.incrementAndGet(); + if (value == null) { + output.appendNull(); + } + else { + type.writeBoolean(output, value); + } + } + + public static void providedBlockSlice(Type type, BlockBuilder output, Slice value) + { + hitProvidedBlockBuilderSlice.incrementAndGet(); + if (value == null) { + output.appendNull(); + } + else { + type.writeSlice(output, value); + } + } + + public static void providedBlockBlock(Type type, BlockBuilder output, Block value) + { + hitProvidedBlockBuilderBlock.incrementAndGet(); + if (value == null) { + output.appendNull(); + } + else { + type.writeObject(output, value); + } + } + + @Override + public boolean isDeterministic() + { + return true; + } + + @Override + public boolean isHidden() + { + return false; + } + + @Override + public String getDescription() + { + return ""; + } + } +}