From 3e60572b324c48a2d66d8ff9f5b684ab6862a5cf Mon Sep 17 00:00:00 2001 From: Wenlei Xie Date: Tue, 25 Dec 2018 17:29:20 -0800 Subject: [PATCH 1/4] Introduce ReturnPlaceConvention for scalar function --- .../metadata/PolymorphicScalarFunction.java | 11 +++++++++- .../PolymorphicScalarFunctionBuilder.java | 17 +++++++++++++- .../scalar/RowDistinctFromOperator.java | 3 +++ .../scalar/ScalarFunctionImplementation.java | 22 ++++++++++++++++++- .../ParametricScalarImplementation.java | 20 ++++++++++++++++- .../presto/sql/gen/BytecodeUtils.java | 6 +++++ 6 files changed, 75 insertions(+), 4 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/PolymorphicScalarFunction.java b/presto-main/src/main/java/com/facebook/presto/metadata/PolymorphicScalarFunction.java index f7820174df591..7e7dcd4d03d5c 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/PolymorphicScalarFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/PolymorphicScalarFunction.java @@ -19,6 +19,7 @@ import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention; +import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; @@ -124,7 +125,7 @@ private ScalarImplementationChoice getScalarFunctionImplementationChoice( List extraParameters = computeExtraParameters(matchingMethodsGroup.get(), context); MethodHandle methodHandle = applyExtraParameters(matchingMethod.get().getMethod(), extraParameters, choice.getArgumentProperties()); - return new ScalarImplementationChoice(choice.isNullableResult(), choice.getArgumentProperties(), methodHandle, Optional.empty()); + return new ScalarImplementationChoice(choice.isNullableResult(), choice.getArgumentProperties(), choice.getReturnPlaceConvention(), methodHandle, Optional.empty()); } private static boolean matchesParameterAndReturnTypes( @@ -218,15 +219,18 @@ static final class PolymorphicScalarFunctionChoice { private final boolean nullableResult; private final List argumentProperties; + private final ReturnPlaceConvention returnPlaceConvention; private final List methodsGroups; PolymorphicScalarFunctionChoice( boolean nullableResult, List argumentProperties, + ReturnPlaceConvention returnPlaceConvention, List methodsGroups) { this.nullableResult = nullableResult; this.argumentProperties = ImmutableList.copyOf(requireNonNull(argumentProperties, "argumentProperties is null")); + this.returnPlaceConvention = requireNonNull(returnPlaceConvention, "returnPlaceConvention is null"); this.methodsGroups = ImmutableList.copyOf(requireNonNull(methodsGroups, "methodsWithExtraParametersFunctions is null")); } @@ -244,5 +248,10 @@ List getArgumentProperties() { return argumentProperties; } + + ReturnPlaceConvention getReturnPlaceConvention() + { + return returnPlaceConvention; + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/PolymorphicScalarFunctionBuilder.java b/presto-main/src/main/java/com/facebook/presto/metadata/PolymorphicScalarFunctionBuilder.java index 7bb1bfc64eb3c..e29e6bee2475d 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/PolymorphicScalarFunctionBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/PolymorphicScalarFunctionBuilder.java @@ -15,6 +15,7 @@ import com.facebook.presto.metadata.PolymorphicScalarFunction.PolymorphicScalarFunctionChoice; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty; +import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention; import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; @@ -250,6 +251,7 @@ public static class ChoiceBuilder private final Signature signature; private boolean nullableResult; private List argumentProperties; + private ReturnPlaceConvention returnPlaceConvention; private final ImmutableList.Builder methodsGroups = ImmutableList.builder(); private ChoiceBuilder(Class clazz, Signature signature) @@ -264,6 +266,10 @@ public ChoiceBuilder implementation(Function instanceFactory, boolean deterministic) { - this(ImmutableList.of(new ScalarImplementationChoice(nullable, argumentProperties, methodHandle, instanceFactory)), deterministic); + this( + ImmutableList.of(new ScalarImplementationChoice( + nullable, + argumentProperties, + ReturnPlaceConvention.STACK, + methodHandle, + instanceFactory)), + deterministic); } /** @@ -107,6 +114,7 @@ public static class ScalarImplementationChoice { private final boolean nullable; private final List argumentProperties; + private final ReturnPlaceConvention returnPlaceConvention; private final MethodHandle methodHandle; private final Optional instanceFactory; private final boolean hasSession; @@ -114,11 +122,13 @@ public static class ScalarImplementationChoice public ScalarImplementationChoice( boolean nullable, List argumentProperties, + ReturnPlaceConvention returnPlaceConvention, MethodHandle methodHandle, Optional instanceFactory) { this.nullable = nullable; this.argumentProperties = ImmutableList.copyOf(requireNonNull(argumentProperties, "argumentProperties is null")); + this.returnPlaceConvention = requireNonNull(returnPlaceConvention, "returnPlaceConvention is null"); this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); this.instanceFactory = requireNonNull(instanceFactory, "instanceFactory is null"); @@ -158,6 +168,11 @@ public ArgumentProperty getArgumentProperty(int argumentIndex) return argumentProperties.get(argumentIndex); } + public ReturnPlaceConvention getReturnPlaceConvention() + { + return returnPlaceConvention; + } + public MethodHandle getMethodHandle() { return methodHandle; @@ -279,4 +294,9 @@ public enum ArgumentType VALUE_TYPE, FUNCTION_TYPE } + + public enum ReturnPlaceConvention + { + STACK + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java index 44ab8498dd446..320d43e87a25a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ParametricScalarImplementation.java @@ -23,6 +23,7 @@ import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention; +import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.block.Block; @@ -169,6 +170,7 @@ public Optional specialize(Signature boundSignatur implementationChoices.add(new ScalarImplementationChoice( choice.nullable, choice.argumentProperties, + choice.returnPlaceConvention, boundMethodHandle.asType(javaMethodType(choice, boundSignature, typeManager)), boundConstructor)); } @@ -317,6 +319,7 @@ public static final class ParametricScalarImplementationChoice { private final boolean nullable; private final List argumentProperties; + private final ReturnPlaceConvention returnPlaceConvention; private final MethodHandle methodHandle; private final Optional constructor; private final List dependencies; @@ -328,6 +331,7 @@ private ParametricScalarImplementationChoice( boolean nullable, boolean hasConnectorSession, List argumentProperties, + ReturnPlaceConvention returnPlaceConvention, MethodHandle methodHandle, Optional constructor, List dependencies, @@ -336,6 +340,7 @@ private ParametricScalarImplementationChoice( this.nullable = nullable; this.hasConnectorSession = hasConnectorSession; this.argumentProperties = ImmutableList.copyOf(requireNonNull(argumentProperties, "argumentProperties is null")); + this.returnPlaceConvention = requireNonNull(returnPlaceConvention, "returnPlaceConvention is null"); this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); this.constructor = requireNonNull(constructor, "constructor is null"); this.dependencies = ImmutableList.copyOf(requireNonNull(dependencies, "dependencies is null")); @@ -375,6 +380,11 @@ public List getArgumentProperties() return argumentProperties; } + public ReturnPlaceConvention getReturnPlaceConvention() + { + return returnPlaceConvention; + } + public boolean checkDependencies() { for (int i = 1; i < getDependencies().size(); i++) { @@ -514,7 +524,15 @@ else if (actualReturnType.isPrimitive()) { this.methodHandle = getMethodHandle(method); - ParametricScalarImplementationChoice choice = new ParametricScalarImplementationChoice(nullable, hasConnectorSession, argumentProperties, methodHandle, constructorMethodHandle, dependencies, constructorDependencies); + ParametricScalarImplementationChoice choice = new ParametricScalarImplementationChoice( + nullable, + hasConnectorSession, + argumentProperties, + ReturnPlaceConvention.STACK, // TODO: support other return place convention + methodHandle, + constructorMethodHandle, + dependencies, + constructorDependencies); choices.add(choice); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java index decbdad5d144f..8c1590ab0664b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java @@ -17,6 +17,7 @@ import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention; +import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.block.BlockBuilder; @@ -181,6 +182,11 @@ public static BytecodeNode generateInvocation(Scope scope, String name, ScalarFu List choices = function.getAllChoices(); ScalarImplementationChoice bestChoice = null; for (ScalarImplementationChoice currentChoice : choices) { + if (currentChoice.getReturnPlaceConvention() != ReturnPlaceConvention.STACK) { + // TODO: support other return place convention + continue; + } + boolean isValid = true; for (int i = 0; i < arguments.size(); i++) { if (currentChoice.getArgumentProperty(i).getArgumentType() != VALUE_TYPE) { From ac0c401eb8cff823e2fa9188afce8689459345b7 Mon Sep 17 00:00:00 2001 From: Wenlei Xie Date: Tue, 15 Aug 2017 15:31:45 -0700 Subject: [PATCH 2/4] Add output block to RowExpressionCompiler.Context --- .../presto/sql/gen/AndCodeGenerator.java | 5 +-- .../sql/gen/BytecodeGeneratorContext.java | 8 ++--- .../presto/sql/gen/CastCodeGenerator.java | 3 +- .../presto/sql/gen/CoalesceCodeGenerator.java | 3 +- .../sql/gen/CursorProcessorCompiler.java | 5 +-- .../sql/gen/DereferenceCodeGenerator.java | 3 +- .../sql/gen/FunctionCallCodeGenerator.java | 4 +-- .../presto/sql/gen/IfCodeGenerator.java | 7 ++-- .../presto/sql/gen/InCodeGenerator.java | 5 +-- .../presto/sql/gen/IsNullCodeGenerator.java | 3 +- .../sql/gen/JoinFilterFunctionCompiler.java | 3 +- .../sql/gen/LambdaBytecodeGenerator.java | 5 +-- .../presto/sql/gen/NullIfCodeGenerator.java | 5 +-- .../presto/sql/gen/OrCodeGenerator.java | 5 +-- .../presto/sql/gen/PageFunctionCompiler.java | 4 +-- .../sql/gen/RowConstructorCodeGenerator.java | 3 +- .../presto/sql/gen/RowExpressionCompiler.java | 32 ++++++++++++++++--- .../presto/sql/gen/SwitchCodeGenerator.java | 9 +++--- 18 files changed, 74 insertions(+), 38 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/AndCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/AndCodeGenerator.java index 10316e366a5fa..c57ca80a8f6ca 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/AndCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/AndCodeGenerator.java @@ -24,6 +24,7 @@ import io.airlift.bytecode.instruction.LabelNode; import java.util.List; +import java.util.Optional; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; @@ -40,8 +41,8 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon .comment("AND") .setDescription("AND"); - BytecodeNode left = generator.generate(arguments.get(0)); - BytecodeNode right = generator.generate(arguments.get(1)); + BytecodeNode left = generator.generate(arguments.get(0), Optional.empty()); + BytecodeNode right = generator.generate(arguments.get(1), Optional.empty()); block.append(left); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java index 49e572486b008..2e71bdee94623 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java @@ -67,14 +67,14 @@ public CallSiteBinder getCallSiteBinder() return callSiteBinder; } - public BytecodeNode generate(RowExpression expression) + public BytecodeNode generate(RowExpression expression, Optional outputBlockVariable) { - return generate(expression, Optional.empty()); + return generate(expression, outputBlockVariable, Optional.empty()); } - public BytecodeNode generate(RowExpression expression, Optional lambdaInterface) + public BytecodeNode generate(RowExpression expression, Optional outputBlockVariable, Optional lambdaInterface) { - return rowExpressionCompiler.compile(expression, scope, lambdaInterface); + return rowExpressionCompiler.compile(expression, scope, outputBlockVariable, lambdaInterface); } public FunctionRegistry getRegistry() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java index d453df605711d..d3f13b80023aa 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java @@ -20,6 +20,7 @@ import io.airlift.bytecode.BytecodeNode; import java.util.List; +import java.util.Optional; public class CastCodeGenerator implements BytecodeGenerator @@ -33,6 +34,6 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon .getRegistry() .getCoercion(argument.getType(), returnType); - return generatorContext.generateCall(function.getName(), generatorContext.getRegistry().getScalarFunctionImplementation(function), ImmutableList.of(generatorContext.generate(argument))); + return generatorContext.generateCall(function.getName(), generatorContext.getRegistry().getScalarFunctionImplementation(function), ImmutableList.of(generatorContext.generate(argument, Optional.empty()))); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/CoalesceCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/CoalesceCodeGenerator.java index 003c03f7023ba..cbaaa70fe049b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/CoalesceCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/CoalesceCodeGenerator.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; @@ -36,7 +37,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon { List operands = new ArrayList<>(); for (RowExpression expression : arguments) { - operands.add(generatorContext.generate(expression)); + operands.add(generatorContext.generate(expression, Optional.empty())); } Variable wasNull = generatorContext.wasNull(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java index f7776b5c8bb83..1fedf090ec75e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java @@ -46,6 +46,7 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.sql.gen.BytecodeUtils.generateWrite; @@ -247,7 +248,7 @@ private void generateFilterMethod( .comment("boolean wasNull = false;") .putVariable(wasNullVariable, false) .comment("evaluate filter: " + filter) - .append(compiler.compile(filter, scope)) + .append(compiler.compile(filter, scope, Optional.empty())) .comment("if (wasNull) return false;") .getVariable(wasNullVariable) .ifFalseGoto(end) @@ -287,7 +288,7 @@ private void generateProjectMethod( .putVariable(wasNullVariable, false) .getVariable(output) .comment("evaluate projection: " + projection.toString()) - .append(compiler.compile(projection, scope)) + .append(compiler.compile(projection, scope, Optional.empty())) .append(generateWrite(callSiteBinder, scope, wasNullVariable, projection.getType())) .ret(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/DereferenceCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/DereferenceCodeGenerator.java index 1d5f75468c91a..c08618de2d32f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/DereferenceCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/DereferenceCodeGenerator.java @@ -26,6 +26,7 @@ import io.airlift.bytecode.instruction.LabelNode; import java.util.List; +import java.util.Optional; import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType; import static com.google.common.base.Preconditions.checkArgument; @@ -47,7 +48,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon // clear the wasNull flag before evaluating the row value block.putVariable(wasNull, false); - block.append(generator.generate(arguments.get(0))).putVariable(rowBlock); + block.append(generator.generate(arguments.get(0), Optional.empty())).putVariable(rowBlock); IfStatement ifRowBlockIsNull = new IfStatement("if row block is null...") .condition(wasNull); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java index a303d7af92bb6..39c9ae87ff059 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java @@ -41,10 +41,10 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon RowExpression argument = arguments.get(i); ScalarFunctionImplementation.ArgumentProperty argumentProperty = function.getArgumentProperty(i); if (argumentProperty.getArgumentType() == VALUE_TYPE) { - argumentsBytecode.add(context.generate(argument)); + argumentsBytecode.add(context.generate(argument, Optional.empty())); } else { - argumentsBytecode.add(context.generate(argument, Optional.of(argumentProperty.getLambdaInterface()))); + argumentsBytecode.add(context.generate(argument, Optional.empty(), Optional.of(argumentProperty.getLambdaInterface()))); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/IfCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/IfCodeGenerator.java index 7ab88e12ee295..360cfa65afec1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/IfCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/IfCodeGenerator.java @@ -23,6 +23,7 @@ import io.airlift.bytecode.control.IfStatement; import java.util.List; +import java.util.Optional; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; @@ -36,7 +37,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon Variable wasNull = context.wasNull(); BytecodeBlock condition = new BytecodeBlock() - .append(context.generate(arguments.get(0))) + .append(context.generate(arguments.get(0), Optional.empty())) .comment("... and condition value was not null") .append(wasNull) .invokeStatic(CompilerOperations.class, "not", boolean.class, boolean.class) @@ -45,7 +46,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon return new IfStatement() .condition(condition) - .ifTrue(context.generate(arguments.get(1))) - .ifFalse(context.generate(arguments.get(2))); + .ifTrue(context.generate(arguments.get(1), Optional.empty())) + .ifFalse(context.generate(arguments.get(2), Optional.empty())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java index 4765f8fbd4b5b..dbbc95c6c4c0e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java @@ -40,6 +40,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.spi.function.OperatorType.HASH_CODE; @@ -129,7 +130,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon ImmutableSet.Builder constantValuesBuilder = ImmutableSet.builder(); for (RowExpression testValue : values) { - BytecodeNode testBytecode = generatorContext.generate(testValue); + BytecodeNode testBytecode = generatorContext.generate(testValue, Optional.empty()); if (isDeterminateConstant(testValue, isIndeterminateFunction.getMethodHandle())) { ConstantExpression constant = (ConstantExpression) testValue; @@ -252,7 +253,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon BytecodeBlock block = new BytecodeBlock() .comment("IN") - .append(generatorContext.generate(arguments.get(0))) + .append(generatorContext.generate(arguments.get(0), Optional.empty())) .append(ifWasNullPopAndGoto(scope, end, boolean.class, javaType)) .putVariable(value) .append(switchBlock) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/IsNullCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/IsNullCodeGenerator.java index f8a44183f8811..1ffc181711209 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/IsNullCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/IsNullCodeGenerator.java @@ -22,6 +22,7 @@ import io.airlift.bytecode.Variable; import java.util.List; +import java.util.Optional; import static com.facebook.presto.type.UnknownType.UNKNOWN; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; @@ -40,7 +41,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon return loadBoolean(true); } - BytecodeNode value = generatorContext.generate(argument); + BytecodeNode value = generatorContext.generate(argument, Optional.empty()); // evaluate the expression, pop the produced value, and load the null flag Variable wasNull = generatorContext.wasNull(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java index bfad5be6f6141..317676b08cc61 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java @@ -50,6 +50,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.sql.gen.BytecodeUtils.invoke; @@ -198,7 +199,7 @@ private void generateFilterMethod( metadata.getFunctionRegistry(), compiledLambdaMap); - BytecodeNode visitorBody = compiler.compile(filter, scope); + BytecodeNode visitorBody = compiler.compile(filter, scope, Optional.empty()); Variable result = scope.declareVariable(boolean.class, "result"); body.append(visitorBody) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java index e40e8d21541d3..378b8ed0e4691 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java @@ -48,6 +48,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.spi.StandardErrorCode.COMPILER_ERROR; @@ -156,7 +157,7 @@ private static CompiledLambda defineLambdaMethod( Scope scope = method.getScope(); Variable wasNull = scope.declareVariable(boolean.class, "wasNull"); - BytecodeNode compiledBody = innerExpressionCompiler.compile(lambda.getBody(), scope); + BytecodeNode compiledBody = innerExpressionCompiler.compile(lambda.getBody(), scope, Optional.empty()); method.getBody() .putVariable(wasNull, false) .append(compiledBody) @@ -197,7 +198,7 @@ public static BytecodeNode generateLambda( for (RowExpression captureExpression : captureExpressions) { Class valueType = Primitives.wrap(captureExpression.getType().getJavaType()); Variable valueVariable = scope.createTempVariable(valueType); - block.append(context.generate(captureExpression)); + block.append(context.generate(captureExpression, Optional.empty())); block.append(boxPrimitiveIfNecessary(scope, valueType)); block.putVariable(valueVariable); block.append(wasNull.set(constantFalse())); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java index 885b71a204c80..b13a18974a93a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java @@ -28,6 +28,7 @@ import io.airlift.bytecode.instruction.LabelNode; import java.util.List; +import java.util.Optional; import static com.facebook.presto.sql.gen.BytecodeUtils.ifWasNullPopAndGoto; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; @@ -49,7 +50,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon Variable firstValue = scope.createTempVariable(first.getType().getJavaType()); BytecodeBlock block = new BytecodeBlock() .comment("check if first arg is null") - .append(generatorContext.generate(first)) + .append(generatorContext.generate(first, Optional.empty())) .append(ifWasNullPopAndGoto(scope, notMatch, void.class)) .dup(first.getType().getJavaType()) .putVariable(firstValue); @@ -65,7 +66,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon equalsFunction, ImmutableList.of( cast(generatorContext, firstValue, firstType, equalsSignature.getArgumentTypes().get(0)), - cast(generatorContext, generatorContext.generate(second), secondType, equalsSignature.getArgumentTypes().get(1)))); + cast(generatorContext, generatorContext.generate(second, Optional.empty()), secondType, equalsSignature.getArgumentTypes().get(1)))); BytecodeBlock conditionBlock = new BytecodeBlock() .append(equalsCall) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/OrCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/OrCodeGenerator.java index ce911f63734ca..042bbd62a082f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/OrCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/OrCodeGenerator.java @@ -24,6 +24,7 @@ import io.airlift.bytecode.instruction.LabelNode; import java.util.List; +import java.util.Optional; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; @@ -40,8 +41,8 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon .comment("OR") .setDescription("OR"); - BytecodeNode left = generator.generate(arguments.get(0)); - BytecodeNode right = generator.generate(arguments.get(1)); + BytecodeNode left = generator.generate(arguments.get(0), Optional.empty()); + BytecodeNode right = generator.generate(arguments.get(1), Optional.empty()); block.append(left); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java index 4d0bf1ef6e030..4be547bb62b1b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java @@ -351,7 +351,7 @@ private MethodDefinition generateEvaluateMethod( compiledLambdaMap); body.append(thisVariable.getField(blockBuilder)) - .append(compiler.compile(projection, scope)) + .append(compiler.compile(projection, scope, Optional.empty())) .append(generateWrite(callSiteBinder, scope, wasNullVariable, projection.getType())) .ret(); return method; @@ -525,7 +525,7 @@ private MethodDefinition generateFilterMethod( compiledLambdaMap); Variable result = scope.declareVariable(boolean.class, "result"); - body.append(compiler.compile(filter, scope)) + body.append(compiler.compile(filter, scope, Optional.empty())) // store result so we can check for null .putVariable(result) .append(and(not(wasNullVariable), result).ret()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/RowConstructorCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowConstructorCodeGenerator.java index 1ed85d55734bb..8867241ed5101 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/RowConstructorCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowConstructorCodeGenerator.java @@ -26,6 +26,7 @@ import io.airlift.bytecode.control.IfStatement; import java.util.List; +import java.util.Optional; import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; @@ -59,7 +60,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon Variable field = scope.createTempVariable(fieldType.getJavaType()); block.comment("Clean wasNull and Generate + " + i + "-th field of row"); block.append(context.wasNull().set(constantFalse())); - block.append(context.generate(arguments.get(i))); + block.append(context.generate(arguments.get(i), Optional.empty())); block.putVariable(field); block.append(new IfStatement() .condition(context.wasNull()) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java index bd01331648dd8..0c38d1edc27c7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java @@ -27,6 +27,7 @@ import io.airlift.bytecode.BytecodeBlock; import io.airlift.bytecode.BytecodeNode; import io.airlift.bytecode.Scope; +import io.airlift.bytecode.Variable; import java.util.Map; import java.util.Optional; @@ -43,6 +44,7 @@ import static com.facebook.presto.sql.relational.Signatures.NULL_IF; import static com.facebook.presto.sql.relational.Signatures.ROW_CONSTRUCTOR; import static com.facebook.presto.sql.relational.Signatures.SWITCH; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; import static io.airlift.bytecode.instruction.Constant.loadBoolean; @@ -74,14 +76,16 @@ public class RowExpressionCompiler this.compiledLambdaMap = compiledLambdaMap; } - public BytecodeNode compile(RowExpression rowExpression, Scope scope) + public BytecodeNode compile(RowExpression rowExpression, Scope scope, Optional outputBlockVariable) { - return compile(rowExpression, scope, Optional.empty()); + return compile(rowExpression, scope, outputBlockVariable, Optional.empty()); } - public BytecodeNode compile(RowExpression rowExpression, Scope scope, Optional lambdaInterface) + // When outputBlockVariable is presented, the generated bytecode will write the evaluated value into the outputBlockVariable, + // otherwise the value will be left on stack. + public BytecodeNode compile(RowExpression rowExpression, Scope scope, Optional outputBlockVariable, Optional lambdaInterface) { - return rowExpression.accept(new Visitor(), new Context(scope, lambdaInterface)); + return rowExpression.accept(new Visitor(), new Context(scope, outputBlockVariable, lambdaInterface)); } private class Visitor @@ -90,6 +94,9 @@ private class Visitor @Override public BytecodeNode visitCall(CallExpression call, Context context) { + // TODO: support write to output block + checkArgument(!context.getOutputBlockVariable().isPresent(), "call expression does not support writing to block"); + BytecodeGenerator generator; // special-cased in function registry if (call.getSignature().getName().equals(CAST)) { @@ -153,6 +160,9 @@ public BytecodeNode visitCall(CallExpression call, Context context) @Override public BytecodeNode visitConstant(ConstantExpression constant, Context context) { + // TODO: support write to output block + checkArgument(!context.getOutputBlockVariable().isPresent(), "constant expression does not support writing to block"); + Object value = constant.getValue(); Class javaType = constant.getType().getJavaType(); @@ -196,12 +206,15 @@ public BytecodeNode visitConstant(ConstantExpression constant, Context context) @Override public BytecodeNode visitInputReference(InputReferenceExpression node, Context context) { + // TODO: support write to output block + checkArgument(!context.getOutputBlockVariable().isPresent(), "input reference expression does not support writing to block"); return fieldReferenceCompiler.visitInputReference(node, context.getScope()); } @Override public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Context context) { + checkArgument(!context.getOutputBlockVariable().isPresent(), "lambda definition expression does not support writing to block"); checkState(compiledLambdaMap.containsKey(lambda), "lambda expressions map does not contain this lambda definition"); if (!context.lambdaInterface.get().isAnnotationPresent(FunctionalInterface.class)) { // lambdaInterface is checked to be annotated with FunctionalInterface when generating ScalarFunctionImplementation @@ -225,6 +238,8 @@ public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Context conte @Override public BytecodeNode visitVariableReference(VariableReferenceExpression reference, Context context) { + // TODO: support write to output block + checkArgument(!context.getOutputBlockVariable().isPresent(), "variable reference expression does not support writing to block"); return fieldReferenceCompiler.visitVariableReference(reference, context.getScope()); } } @@ -232,11 +247,13 @@ public BytecodeNode visitVariableReference(VariableReferenceExpression reference private static class Context { private final Scope scope; + private final Optional outputBlockVariable; private final Optional lambdaInterface; - public Context(Scope scope, Optional lambdaInterface) + public Context(Scope scope, Optional outputBlockVariable, Optional lambdaInterface) { this.scope = scope; + this.outputBlockVariable = outputBlockVariable; this.lambdaInterface = lambdaInterface; } @@ -245,6 +262,11 @@ public Scope getScope() return scope; } + public Optional getOutputBlockVariable() + { + return outputBlockVariable; + } + public Optional getLambdaInterface() { return lambdaInterface; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java index 668f8dc2833e4..5311f0df1b89e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java @@ -30,6 +30,7 @@ import io.airlift.bytecode.instruction.VariableInstruction; import java.util.List; +import java.util.Optional; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; @@ -74,7 +75,7 @@ else if ( == ) { // process value, else, and all when clauses RowExpression value = arguments.get(0); - BytecodeNode valueBytecode = generatorContext.generate(value); + BytecodeNode valueBytecode = generatorContext.generate(value, Optional.empty()); BytecodeNode elseValue; List whenClauses; @@ -87,7 +88,7 @@ else if ( == ) { } else { whenClauses = arguments.subList(1, arguments.size() - 1); - elseValue = generatorContext.generate(last); + elseValue = generatorContext.generate(last, Optional.empty()); } // determine the type of the value and result @@ -123,7 +124,7 @@ else if ( == ) { BytecodeNode equalsCall = generatorContext.generateCall( equalsFunction.getName(), generatorContext.getRegistry().getScalarFunctionImplementation(equalsFunction), - ImmutableList.of(generatorContext.generate(operand), getTempVariableNode)); + ImmutableList.of(generatorContext.generate(operand, Optional.empty()), getTempVariableNode)); BytecodeBlock condition = new BytecodeBlock() .append(equalsCall) @@ -131,7 +132,7 @@ else if ( == ) { elseValue = new IfStatement("when") .condition(condition) - .ifTrue(generatorContext.generate(result)) + .ifTrue(generatorContext.generate(result, Optional.empty())) .ifFalse(elseValue); } From 52ab27d965eec291471d45caaa2fc7144b7ba840 Mon Sep 17 00:00:00 2001 From: Wenlei Xie Date: Tue, 15 Aug 2017 21:12:51 -0700 Subject: [PATCH 3/4] Support compiling RowExpression write to output block --- .../presto/sql/gen/AndCodeGenerator.java | 4 +- .../presto/sql/gen/BindCodeGenerator.java | 11 +- .../presto/sql/gen/BytecodeGenerator.java | 14 ++- .../sql/gen/BytecodeGeneratorContext.java | 14 ++- .../presto/sql/gen/BytecodeUtils.java | 73 ++++++++++-- .../presto/sql/gen/CastCodeGenerator.java | 14 ++- .../presto/sql/gen/CoalesceCodeGenerator.java | 8 +- .../sql/gen/CursorProcessorCompiler.java | 5 +- .../sql/gen/DereferenceCodeGenerator.java | 4 +- .../sql/gen/FunctionCallCodeGenerator.java | 10 +- .../presto/sql/gen/IfCodeGenerator.java | 15 ++- .../presto/sql/gen/InCodeGenerator.java | 4 +- .../gen/InvokeFunctionBytecodeExpression.java | 8 +- .../presto/sql/gen/IsNullCodeGenerator.java | 4 +- .../presto/sql/gen/NullIfCodeGenerator.java | 4 +- .../presto/sql/gen/OrCodeGenerator.java | 4 +- .../presto/sql/gen/PageFunctionCompiler.java | 9 +- .../sql/gen/RowConstructorCodeGenerator.java | 4 +- .../presto/sql/gen/RowExpressionCompiler.java | 106 +++++++++++------- .../presto/sql/gen/SwitchCodeGenerator.java | 7 +- 20 files changed, 236 insertions(+), 86 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/AndCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/AndCodeGenerator.java index c57ca80a8f6ca..8e90a2d53efe0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/AndCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/AndCodeGenerator.java @@ -26,13 +26,14 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; public class AndCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generator, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generator, Type returnType, List arguments, Optional outputBlockVariable) { Preconditions.checkArgument(arguments.size() == 2); @@ -98,6 +99,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon block.append(ifRightIsNull) .visitLabel(end); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generator, returnType, output))); return block; } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/BindCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/BindCodeGenerator.java index 3c7df4a53bdb5..f4c1b0a9f6b56 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/BindCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/BindCodeGenerator.java @@ -20,10 +20,13 @@ import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.RowExpression; import io.airlift.bytecode.BytecodeNode; +import io.airlift.bytecode.Variable; import java.util.List; import java.util.Map; +import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; public class BindCodeGenerator @@ -39,12 +42,18 @@ public BindCodeGenerator(Map compile } @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments, Optional outputBlockVariable) { // Bind expression is used to generate captured lambda. // It takes the captured values and the uncaptured lambda, and produces captured lambda as the output. // The uncaptured lambda is just a method, and does not have a stack representation during execution. // As a result, the bind expression generates the captured lambda in one step. + + // outputBlockVariable cannot present because + // 1. bind cannot be in the top level of an expression + // 2. lambda cannot be put into blocks. + checkArgument(!outputBlockVariable.isPresent()); + int numCaptures = arguments.size() - 1; LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) arguments.get(numCaptures); checkState(compiledLambdaMap.containsKey(lambda), "lambda expressions map does not contain this lambda definition"); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGenerator.java index 3aa819475dab2..588e6987ada00 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGenerator.java @@ -17,10 +17,22 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.relational.RowExpression; import io.airlift.bytecode.BytecodeNode; +import io.airlift.bytecode.Variable; import java.util.List; +import java.util.Optional; public interface BytecodeGenerator { - BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments); + BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments, Optional outputBlockVariable); + + static BytecodeNode generateWrite(BytecodeGeneratorContext context, Type returnType, Variable outputBlock) + { + return BytecodeUtils.generateWrite( + context.getCallSiteBinder(), + context.getScope(), + context.getScope().getVariable("wasNull"), + returnType, + outputBlock); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java index 2e71bdee94623..e96e348df8227 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java @@ -15,6 +15,7 @@ import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; +import com.facebook.presto.sql.gen.BytecodeUtils.OutputBlockVariableAndType; import com.facebook.presto.sql.relational.RowExpression; import io.airlift.bytecode.BytecodeNode; import io.airlift.bytecode.FieldDefinition; @@ -82,17 +83,26 @@ public FunctionRegistry getRegistry() return registry; } + public BytecodeNode generateCall(String name, ScalarFunctionImplementation function, List arguments) + { + return generateCall(name, function, arguments, Optional.empty()); + } + /** * Generates a function call with null handling, automatic binding of session parameter, etc. */ - public BytecodeNode generateCall(String name, ScalarFunctionImplementation function, List arguments) + public BytecodeNode generateCall( + String name, + ScalarFunctionImplementation function, + List arguments, + Optional outputBlockVariableAndType) { Optional instance = Optional.empty(); if (function.getInstanceFactory().isPresent()) { FieldDefinition field = cachedInstanceBinder.getCachedInstance(function.getInstanceFactory().get()); instance = Optional.of(scope.getThis().getField(field)); } - return generateInvocation(scope, name, function, instance, arguments, callSiteBinder); + return generateInvocation(scope, name, function, instance, arguments, callSiteBinder, outputBlockVariableAndType); } public Variable wasNull() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java index 8c1590ab0664b..58a00eeedfe6f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java @@ -51,6 +51,7 @@ import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; import static io.airlift.bytecode.expression.BytecodeExpressions.invokeDynamic; import static java.lang.String.format; +import static java.util.Objects.requireNonNull; public final class BytecodeUtils { @@ -161,7 +162,32 @@ public static BytecodeExpression loadConstant(Binding binding) binding.getType().returnType()); } - public static BytecodeNode generateInvocation(Scope scope, String name, ScalarFunctionImplementation function, Optional instance, List arguments, CallSiteBinder binder) + public static BytecodeNode generateInvocation( + Scope scope, + String name, + ScalarFunctionImplementation function, + Optional instance, + List arguments, + CallSiteBinder binder) + { + return generateInvocation( + scope, + name, + function, + instance, + arguments, + binder, + Optional.empty()); + } + + public static BytecodeNode generateInvocation( + Scope scope, + String name, + ScalarFunctionImplementation function, + Optional instance, + List arguments, + CallSiteBinder binder, + Optional outputBlockVariableAndType) { LabelNode end = new LabelNode("end"); BytecodeBlock block = new BytecodeBlock() @@ -271,6 +297,9 @@ else if (type == ConnectorSession.class) { } block.visitLabel(end); + if (outputBlockVariableAndType.isPresent()) { + block.append(generateWrite(binder, scope, scope.getVariable("wasNull"), outputBlockVariableAndType.get().getType(), outputBlockVariableAndType.get().getOutputBlockVariable())); + } return block; } @@ -350,7 +379,12 @@ public static BytecodeExpression invoke(Binding binding, Signature signature) return invoke(binding, signature.getName()); } - public static BytecodeNode generateWrite(CallSiteBinder callSiteBinder, Scope scope, Variable wasNullVariable, Type type) + public static BytecodeNode generateWrite( + CallSiteBinder callSiteBinder, + Scope scope, + Variable wasNullVariable, + Type type, + Variable outputBlockVariable) { Class valueJavaType = type.getJavaType(); if (!valueJavaType.isPrimitive() && valueJavaType != Slice.class) { @@ -358,15 +392,8 @@ public static BytecodeNode generateWrite(CallSiteBinder callSiteBinder, Scope sc } String methodName = "write" + Primitives.wrap(valueJavaType).getSimpleName(); - // the stack contains [output, value] - - // We should be able to insert the code to get the output variable and compute the value - // at the right place instead of assuming they are in the stack. We should also not need to - // use temp variables to re-shuffle the stack to the right shape before Type.writeXXX is called - // Unfortunately, because of the assumptions made by try_cast, we can't get around it yet. - // TODO: clean up once try_cast is fixed + // the value to be written is at the top of stack Variable tempValue = scope.createTempVariable(valueJavaType); - Variable tempOutput = scope.createTempVariable(BlockBuilder.class); return new BytecodeBlock() .comment("if (wasNull)") .append(new IfStatement() @@ -374,15 +401,37 @@ public static BytecodeNode generateWrite(CallSiteBinder callSiteBinder, Scope sc .ifTrue(new BytecodeBlock() .comment("output.appendNull();") .pop(valueJavaType) + .getVariable(outputBlockVariable) .invokeInterface(BlockBuilder.class, "appendNull", BlockBuilder.class) .pop()) .ifFalse(new BytecodeBlock() .comment("%s.%s(output, %s)", type.getTypeSignature(), methodName, valueJavaType.getSimpleName()) .putVariable(tempValue) - .putVariable(tempOutput) .append(loadConstant(callSiteBinder.bind(type, Type.class))) - .getVariable(tempOutput) + .getVariable(outputBlockVariable) .getVariable(tempValue) .invokeInterface(Type.class, methodName, void.class, BlockBuilder.class, valueJavaType))); } + + public static class OutputBlockVariableAndType + { + private final Variable outputBlockVariable; + private final Type type; + + public OutputBlockVariableAndType(Variable outputBlockVariable, Type type) + { + this.outputBlockVariable = requireNonNull(outputBlockVariable); + this.type = requireNonNull(type); + } + + public Variable getOutputBlockVariable() + { + return outputBlockVariable; + } + + public Type getType() + { + return type; + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java index d3f13b80023aa..990e2a00d68f0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java @@ -17,16 +17,20 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.relational.RowExpression; import com.google.common.collect.ImmutableList; +import io.airlift.bytecode.BytecodeBlock; import io.airlift.bytecode.BytecodeNode; +import io.airlift.bytecode.Variable; import java.util.List; import java.util.Optional; +import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; + public class CastCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments, Optional outputBlockVariable) { RowExpression argument = arguments.get(0); @@ -34,6 +38,12 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon .getRegistry() .getCoercion(argument.getType(), returnType); - return generatorContext.generateCall(function.getName(), generatorContext.getRegistry().getScalarFunctionImplementation(function), ImmutableList.of(generatorContext.generate(argument, Optional.empty()))); + BytecodeBlock block = new BytecodeBlock() + .append(generatorContext.generateCall( + function.getName(), + generatorContext.getRegistry().getScalarFunctionImplementation(function), + ImmutableList.of(generatorContext.generate(argument, Optional.empty())))); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generatorContext, returnType, output))); + return block; } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/CoalesceCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/CoalesceCodeGenerator.java index cbaaa70fe049b..2ca48b7868816 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/CoalesceCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/CoalesceCodeGenerator.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; @@ -33,7 +34,7 @@ public class CoalesceCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments, Optional outputBlockVariable) { List operands = new ArrayList<>(); for (RowExpression expression : arguments) { @@ -62,6 +63,9 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon nullValue = ifStatement; } - return nullValue; + BytecodeBlock block = new BytecodeBlock() + .append(nullValue); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generatorContext, returnType, output))); + return block; } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java index 1fedf090ec75e..c4a18c9ffc7af 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java @@ -49,7 +49,6 @@ import java.util.Optional; import java.util.Set; -import static com.facebook.presto.sql.gen.BytecodeUtils.generateWrite; import static com.facebook.presto.sql.gen.LambdaExpressionExtractor.extractLambdaExpressions; import static io.airlift.bytecode.Access.PUBLIC; import static io.airlift.bytecode.Access.a; @@ -286,10 +285,8 @@ private void generateProjectMethod( method.getBody() .comment("boolean wasNull = false;") .putVariable(wasNullVariable, false) - .getVariable(output) .comment("evaluate projection: " + projection.toString()) - .append(compiler.compile(projection, scope, Optional.empty())) - .append(generateWrite(callSiteBinder, scope, wasNullVariable, projection.getType())) + .append(compiler.compile(projection, scope, Optional.of(output))) .ret(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/DereferenceCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/DereferenceCodeGenerator.java index c08618de2d32f..236f5278527a2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/DereferenceCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/DereferenceCodeGenerator.java @@ -28,6 +28,7 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType; import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; @@ -36,7 +37,7 @@ public class DereferenceCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generator, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generator, Type returnType, List arguments, Optional outputBlockVariable) { checkArgument(arguments.size() == 2); CallSiteBinder callSiteBinder = generator.getCallSiteBinder(); @@ -85,6 +86,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon block.append(ifFieldIsNull) .visitLabel(end); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generator, returnType, output))); return block; } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java index 39c9ae87ff059..5f70d12e8cbb6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java @@ -17,8 +17,10 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.gen.BytecodeUtils.OutputBlockVariableAndType; import com.facebook.presto.sql.relational.RowExpression; import io.airlift.bytecode.BytecodeNode; +import io.airlift.bytecode.Variable; import java.util.ArrayList; import java.util.List; @@ -30,7 +32,7 @@ public class FunctionCallCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments, Optional outputBlockVariable) { FunctionRegistry registry = context.getRegistry(); @@ -48,6 +50,10 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon } } - return context.generateCall(signature.getName(), function, argumentsBytecode); + return context.generateCall( + signature.getName(), + function, + argumentsBytecode, + outputBlockVariable.map(variable -> new OutputBlockVariableAndType(variable, returnType))); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/IfCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/IfCodeGenerator.java index 360cfa65afec1..740fc6e837d38 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/IfCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/IfCodeGenerator.java @@ -25,13 +25,13 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; public class IfCodeGenerator implements BytecodeGenerator { - @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments, Optional outputBlockVariable) { Preconditions.checkArgument(arguments.size() == 3); @@ -44,9 +44,12 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon .invokeStatic(CompilerOperations.class, "and", boolean.class, boolean.class, boolean.class) .append(wasNull.set(constantFalse())); - return new IfStatement() - .condition(condition) - .ifTrue(context.generate(arguments.get(1), Optional.empty())) - .ifFalse(context.generate(arguments.get(2), Optional.empty())); + BytecodeBlock block = new BytecodeBlock() + .append(new IfStatement() + .condition(condition) + .ifTrue(context.generate(arguments.get(1), Optional.empty())) + .ifFalse(context.generate(arguments.get(2), Optional.empty()))); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(context, returnType, output))); + return block; } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java index dbbc95c6c4c0e..cee5106613355 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java @@ -45,6 +45,7 @@ import static com.facebook.presto.spi.function.OperatorType.HASH_CODE; import static com.facebook.presto.spi.function.OperatorType.INDETERMINATE; +import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; import static com.facebook.presto.sql.gen.BytecodeUtils.ifWasNullPopAndGoto; import static com.facebook.presto.sql.gen.BytecodeUtils.invoke; import static com.facebook.presto.sql.gen.BytecodeUtils.loadConstant; @@ -108,7 +109,7 @@ static SwitchGenerationCase checkSwitchGenerationCase(Type type, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments, Optional outputBlockVariable) { List values = arguments.subList(1, arguments.size()); // empty IN statements are not allowed by the standard, and not possible here @@ -277,6 +278,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon block.visitLabel(end); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generatorContext, returnType, output))); return block; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/InvokeFunctionBytecodeExpression.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/InvokeFunctionBytecodeExpression.java index ed50f791c00e0..d02a8f2460421 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/InvokeFunctionBytecodeExpression.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/InvokeFunctionBytecodeExpression.java @@ -65,7 +65,13 @@ private InvokeFunctionBytecodeExpression( { super(type(Primitives.unwrap(function.getMethodHandle().type().returnType()))); - this.invocation = generateInvocation(scope, name, function, instance, parameters.stream().map(BytecodeNode.class::cast).collect(toImmutableList()), binder); + this.invocation = generateInvocation( + scope, + name, + function, + instance, + parameters.stream().map(BytecodeNode.class::cast).collect(toImmutableList()), + binder); this.oneLineDescription = name + "(" + Joiner.on(", ").join(parameters) + ")"; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/IsNullCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/IsNullCodeGenerator.java index 1ffc181711209..631eff1a74c0c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/IsNullCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/IsNullCodeGenerator.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; import static com.facebook.presto.type.UnknownType.UNKNOWN; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.instruction.Constant.loadBoolean; @@ -32,7 +33,7 @@ public class IsNullCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments, Optional outputBlockVariable) { Preconditions.checkArgument(arguments.size() == 1); @@ -54,6 +55,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon // clear the null flag block.append(wasNull.set(constantFalse())); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generatorContext, returnType, output))); return block; } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java index b13a18974a93a..23d2a23f54922 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java @@ -30,6 +30,7 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; import static com.facebook.presto.sql.gen.BytecodeUtils.ifWasNullPopAndGoto; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; @@ -37,7 +38,7 @@ public class NullIfCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments, Optional outputBlockVariable) { Scope scope = generatorContext.getScope(); @@ -84,6 +85,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon .ifTrue(trueBlock) .ifFalse(notMatch)); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generatorContext, returnType, output))); return block; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/OrCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/OrCodeGenerator.java index 042bbd62a082f..d7dd173ce61e7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/OrCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/OrCodeGenerator.java @@ -26,13 +26,14 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; public class OrCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generator, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generator, Type returnType, List arguments, Optional outputBlockVariable) { Preconditions.checkArgument(arguments.size() == 2); @@ -97,6 +98,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon block.append(ifRightIsNull) .visitLabel(end); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generator, returnType, output))); return block; } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java index 4be547bb62b1b..d3a277e65690a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java @@ -70,7 +70,6 @@ import static com.facebook.presto.operator.project.PageFieldsToInputParametersRewriter.rewritePageFieldsToInputParameters; import static com.facebook.presto.spi.StandardErrorCode.COMPILER_ERROR; -import static com.facebook.presto.sql.gen.BytecodeUtils.generateWrite; import static com.facebook.presto.sql.gen.BytecodeUtils.invoke; import static com.facebook.presto.sql.gen.LambdaExpressionExtractor.extractLambdaExpressions; import static com.facebook.presto.util.CompilerUtils.defineClass; @@ -342,7 +341,7 @@ private MethodDefinition generateEvaluateMethod( declareBlockVariables(projection, page, scope, body); - Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse()); + scope.declareVariable("wasNull", body, constantFalse()); RowExpressionCompiler compiler = new RowExpressionCompiler( callSiteBinder, cachedInstanceBinder, @@ -350,9 +349,9 @@ private MethodDefinition generateEvaluateMethod( metadata.getFunctionRegistry(), compiledLambdaMap); - body.append(thisVariable.getField(blockBuilder)) - .append(compiler.compile(projection, scope, Optional.empty())) - .append(generateWrite(callSiteBinder, scope, wasNullVariable, projection.getType())) + Variable outputBlockVariable = scope.createTempVariable(BlockBuilder.class); + body.append(outputBlockVariable.set(thisVariable.getField(blockBuilder))) + .append(compiler.compile(projection, scope, Optional.of(outputBlockVariable))) .ret(); return method; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/RowConstructorCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowConstructorCodeGenerator.java index 8867241ed5101..756e4e7085021 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/RowConstructorCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowConstructorCodeGenerator.java @@ -28,6 +28,7 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; @@ -37,7 +38,7 @@ public class RowConstructorCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type rowType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type rowType, List arguments, Optional outputBlockVariable) { BytecodeBlock block = new BytecodeBlock().setDescription("Constructor for " + rowType.toString()); CallSiteBinder binder = context.getCallSiteBinder(); @@ -72,6 +73,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon block.append(constantType(binder, rowType).invoke("getObject", Object.class, blockBuilder.cast(Block.class), constantInt(0)) .cast(Block.class)); block.append(context.wasNull().set(constantFalse())); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(context, rowType, output))); return block; } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java index 0c38d1edc27c7..81e47036950d1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java @@ -32,6 +32,7 @@ import java.util.Map; import java.util.Optional; +import static com.facebook.presto.sql.gen.BytecodeUtils.generateWrite; import static com.facebook.presto.sql.gen.BytecodeUtils.loadConstant; import static com.facebook.presto.sql.gen.LambdaBytecodeGenerator.generateLambda; import static com.facebook.presto.sql.relational.Signatures.BIND; @@ -94,9 +95,6 @@ private class Visitor @Override public BytecodeNode visitCall(CallExpression call, Context context) { - // TODO: support write to output block - checkArgument(!context.getOutputBlockVariable().isPresent(), "call expression does not support writing to block"); - BytecodeGenerator generator; // special-cased in function registry if (call.getSignature().getName().equals(CAST)) { @@ -154,61 +152,81 @@ public BytecodeNode visitCall(CallExpression call, Context context) cachedInstanceBinder, registry); - return generator.generateExpression(call.getSignature(), generatorContext, call.getType(), call.getArguments()); + return generator.generateExpression(call.getSignature(), generatorContext, call.getType(), call.getArguments(), context.getOutputBlockVariable()); } @Override public BytecodeNode visitConstant(ConstantExpression constant, Context context) { - // TODO: support write to output block - checkArgument(!context.getOutputBlockVariable().isPresent(), "constant expression does not support writing to block"); - Object value = constant.getValue(); Class javaType = constant.getType().getJavaType(); BytecodeBlock block = new BytecodeBlock(); if (value == null) { - return block.comment("constant null") + block.comment("constant null") .append(context.getScope().getVariable("wasNull").set(constantTrue())) .pushJavaDefault(javaType); } + else { + // use LDC for primitives (boolean, short, int, long, float, double) + block.comment("constant " + constant.getType().getTypeSignature()); + if (javaType == boolean.class) { + block.append(loadBoolean((Boolean) value)); + } + else if (javaType == byte.class || javaType == short.class || javaType == int.class) { + block.append(loadInt(((Number) value).intValue())); + } + else if (javaType == long.class) { + block.append(loadLong((Long) value)); + } + else if (javaType == float.class) { + block.append(loadFloat((Float) value)); + } + else if (javaType == double.class) { + block.append(loadDouble((Double) value)); + } + else if (javaType == String.class) { + block.append(loadString((String) value)); + } + else { + // bind constant object directly into the call-site using invoke dynamic + Binding binding = callSiteBinder.bind(value, constant.getType().getJavaType()); - // use LDC for primitives (boolean, short, int, long, float, double) - block.comment("constant " + constant.getType().getTypeSignature()); - if (javaType == boolean.class) { - return block.append(loadBoolean((Boolean) value)); - } - if (javaType == byte.class || javaType == short.class || javaType == int.class) { - return block.append(loadInt(((Number) value).intValue())); - } - if (javaType == long.class) { - return block.append(loadLong((Long) value)); - } - if (javaType == float.class) { - return block.append(loadFloat((Float) value)); - } - if (javaType == double.class) { - return block.append(loadDouble((Double) value)); - } - if (javaType == String.class) { - return block.append(loadString((String) value)); + block = new BytecodeBlock() + .setDescription("constant " + constant.getType()) + .comment(constant.toString()) + .append(loadConstant(binding)); + } } - // bind constant object directly into the call-site using invoke dynamic - Binding binding = callSiteBinder.bind(value, constant.getType().getJavaType()); + if (context.getOutputBlockVariable().isPresent()) { + block.append(generateWrite( + callSiteBinder, + context.getScope(), + context.getScope().getVariable("wasNull"), + constant.getType(), + context.getOutputBlockVariable().get())); + } - return new BytecodeBlock() - .setDescription("constant " + constant.getType()) - .comment(constant.toString()) - .append(loadConstant(binding)); + return block; } @Override public BytecodeNode visitInputReference(InputReferenceExpression node, Context context) { - // TODO: support write to output block - checkArgument(!context.getOutputBlockVariable().isPresent(), "input reference expression does not support writing to block"); - return fieldReferenceCompiler.visitInputReference(node, context.getScope()); + BytecodeNode inputReferenceBytecode = fieldReferenceCompiler.visitInputReference(node, context.getScope()); + if (!context.getOutputBlockVariable().isPresent()) { + return inputReferenceBytecode; + } + + return new BytecodeBlock() + .append(inputReferenceBytecode) + .append(generateWrite( + callSiteBinder, + context.getScope(), + context.getScope().getVariable("wasNull"), + node.getType(), + context.getOutputBlockVariable().get())); } @Override @@ -238,9 +256,19 @@ public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Context conte @Override public BytecodeNode visitVariableReference(VariableReferenceExpression reference, Context context) { - // TODO: support write to output block - checkArgument(!context.getOutputBlockVariable().isPresent(), "variable reference expression does not support writing to block"); - return fieldReferenceCompiler.visitVariableReference(reference, context.getScope()); + BytecodeNode variableReferenceByteCode = fieldReferenceCompiler.visitVariableReference(reference, context.getScope()); + if (!context.getOutputBlockVariable().isPresent()) { + return variableReferenceByteCode; + } + + return new BytecodeBlock() + .append(variableReferenceByteCode) + .append(generateWrite( + callSiteBinder, + context.getScope(), + context.getScope().getVariable("wasNull"), + reference.getType(), + context.getOutputBlockVariable().get())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java index 5311f0df1b89e..175bbdfe14edd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java @@ -32,6 +32,7 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; @@ -39,7 +40,7 @@ public class SwitchCodeGenerator implements BytecodeGenerator { @Override - public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments) + public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext, Type returnType, List arguments, Optional outputBlockVariable) { // TODO: compile as /* @@ -136,6 +137,8 @@ else if ( == ) { .ifFalse(elseValue); } - return block.append(elseValue); + block.append(elseValue); + outputBlockVariable.ifPresent(output -> block.append(generateWrite(generatorContext, returnType, output))); + return block; } } From 93c7af842e74105e3c24f4a1b205d86b89fc9518 Mon Sep 17 00:00:00 2001 From: Wenlei Xie Date: Thu, 27 Dec 2018 13:52:18 -0800 Subject: [PATCH 4/4] Implement PROVIDED_BLOCKBUILDER return place convention Currently the only return place convention for scalar function is STACK, and the callee will append the value on stack into the result BlockBuilder (for outermost function call) or use it to invoke other functions (for inner function call like `f(g(x))`). For functions returns Slice and Block, this return place convention usually result in copying the data twice -- once generate the data, once copy the data into the output BlockBuilder. This commit implements PROVIDED_BLOCKBUILDER return place convention to allow scalar function implementation choice to directly write to the provided block builder. Similar to the BLOCK_POSITION null convention, a implementation choice used default STACK return place convention must be implemented. Besides STACK return place convention, the developer of the scalar function can choose to provide an additional implementation choice with PROVIDED_BLOCK return place convention. In the future, an invocation adapter should be able to automatically adapt between different calling conventions (null convention and return place convention) when feasible. --- .../scalar/ScalarFunctionImplementation.java | 3 +- .../presto/sql/gen/BytecodeUtils.java | 85 +++- ...idedBlockBuilderReturnPlaceConvention.java | 401 ++++++++++++++++++ 3 files changed, 468 insertions(+), 21 deletions(-) create mode 100644 presto-main/src/test/java/com/facebook/presto/operator/scalar/TestProvidedBlockBuilderReturnPlaceConvention.java diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarFunctionImplementation.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarFunctionImplementation.java index 8abeb74ce5795..219b1bb23fa5e 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarFunctionImplementation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarFunctionImplementation.java @@ -297,6 +297,7 @@ public enum ArgumentType public enum ReturnPlaceConvention { - STACK + STACK, + PROVIDED_BLOCKBUILDER } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java index 58a00eeedfe6f..8bd78ffa6c691 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java @@ -17,7 +17,6 @@ import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention; -import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.block.BlockBuilder; @@ -43,6 +42,7 @@ import java.util.Optional; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentType.VALUE_TYPE; +import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention.PROVIDED_BLOCKBUILDER; import static com.facebook.presto.sql.gen.Bootstrap.BOOTSTRAP_METHOD; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -61,25 +61,35 @@ private BytecodeUtils() public static BytecodeNode ifWasNullPopAndGoto(Scope scope, LabelNode label, Class returnType, Class... stackArgsToPop) { - return handleNullValue(scope, label, returnType, ImmutableList.copyOf(stackArgsToPop), false); + return handleNullValue(scope, label, returnType, ImmutableList.copyOf(stackArgsToPop), Optional.empty(), false); } - public static BytecodeNode ifWasNullPopAndGoto(Scope scope, LabelNode label, Class returnType, Iterable> stackArgsToPop) + public static BytecodeNode ifWasNullPopAndGoto(Scope scope, LabelNode label, Class methodReturnType, Iterable> stackArgsToPop) { - return handleNullValue(scope, label, returnType, ImmutableList.copyOf(stackArgsToPop), false); + return handleNullValue(scope, label, methodReturnType, ImmutableList.copyOf(stackArgsToPop), Optional.empty(), false); } - public static BytecodeNode ifWasNullClearPopAndGoto(Scope scope, LabelNode label, Class returnType, Class... stackArgsToPop) + public static BytecodeNode ifWasNullClearPopAndGoto(Scope scope, LabelNode label, Class methodReturnType, Class... stackArgsToPop) { - return handleNullValue(scope, label, returnType, ImmutableList.copyOf(stackArgsToPop), true); + return handleNullValue(scope, label, methodReturnType, ImmutableList.copyOf(stackArgsToPop), Optional.empty(), true); + } + + public static BytecodeNode ifWasNullClearPopAppendAndGoto(Scope scope, LabelNode label, Class methodReturnType, Variable outputBlockVariable, Iterable> stackArgsToPop) + { + return handleNullValue(scope, label, methodReturnType, ImmutableList.copyOf(stackArgsToPop), Optional.of(outputBlockVariable), true); } public static BytecodeNode handleNullValue(Scope scope, LabelNode label, - Class returnType, + Class methodReturnType, List> stackArgsToPop, + Optional outputBlockVariable, boolean clearNullFlag) { + if (outputBlockVariable.isPresent()) { + checkArgument(methodReturnType == void.class); + } + Variable wasNull = scope.getVariable("wasNull"); BytecodeBlock nullCheck = new BytecodeBlock() @@ -97,9 +107,17 @@ public static BytecodeNode handleNullValue(Scope scope, isNull.pop(parameterType); } - isNull.pushJavaDefault(returnType); - String loadDefaultComment = null; - loadDefaultComment = format("loadJavaDefault(%s)", returnType.getName()); + String loadDefaultOrAppendNullComment; + if (!outputBlockVariable.isPresent()) { + isNull.pushJavaDefault(methodReturnType); + loadDefaultOrAppendNullComment = format("loadJavaDefault(%s)", methodReturnType.getName()); + } + else { + isNull.append(outputBlockVariable.get() + .invoke("appendNull", BlockBuilder.class) + .pop()); + loadDefaultOrAppendNullComment = "appendNullToOutputBlock"; + } isNull.gotoLabel(label); @@ -108,7 +126,7 @@ public static BytecodeNode handleNullValue(Scope scope, popComment = format("pop(%s)", Joiner.on(", ").join(stackArgsToPop)); } - return new IfStatement("if wasNull then %s", Joiner.on(", ").skipNulls().join(clearComment, popComment, loadDefaultComment, "goto " + label.getLabel())) + return new IfStatement("if wasNull then %s", Joiner.on(", ").skipNulls().join(clearComment, popComment, loadDefaultOrAppendNullComment, "goto " + label.getLabel())) .condition(nullCheck) .ifTrue(isNull); } @@ -208,17 +226,13 @@ public static BytecodeNode generateInvocation( List choices = function.getAllChoices(); ScalarImplementationChoice bestChoice = null; for (ScalarImplementationChoice currentChoice : choices) { - if (currentChoice.getReturnPlaceConvention() != ReturnPlaceConvention.STACK) { - // TODO: support other return place convention - continue; - } - boolean isValid = true; for (int i = 0; i < arguments.size(); i++) { if (currentChoice.getArgumentProperty(i).getArgumentType() != VALUE_TYPE) { continue; } - if (!(arguments.get(i) instanceof InputReferenceNode) && currentChoice.getArgumentProperty(i).getNullConvention() == NullConvention.BLOCK_AND_POSITION) { + if (currentChoice.getArgumentProperty(i).getNullConvention() == NullConvention.BLOCK_AND_POSITION && !(arguments.get(i) instanceof InputReferenceNode) + || currentChoice.getReturnPlaceConvention() == PROVIDED_BLOCKBUILDER && (!outputBlockVariableAndType.isPresent())) { isValid = false; break; } @@ -247,6 +261,9 @@ public static BytecodeNode generateInvocation( else if (type == ConnectorSession.class) { block.append(scope.getVariable("session")); } + else if (type == BlockBuilder.class) { + block.append(outputBlockVariableAndType.get().getOutputBlockVariable()); + } else { ArgumentProperty argumentProperty = bestChoice.getArgumentProperty(realParameterIndex); switch (argumentProperty.getArgumentType()) { @@ -256,7 +273,17 @@ else if (type == ConnectorSession.class) { case RETURN_NULL_ON_NULL: block.append(arguments.get(realParameterIndex)); checkArgument(!Primitives.isWrapperType(type), "Non-nullable argument must not be primitive wrapper type"); - block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes))); + switch (bestChoice.getReturnPlaceConvention()) { + case STACK: + block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes))); + break; + case PROVIDED_BLOCKBUILDER: + checkArgument(unboxedReturnType == void.class); + block.append(ifWasNullClearPopAppendAndGoto(scope, end, unboxedReturnType, outputBlockVariableAndType.get().getOutputBlockVariable(), Lists.reverse(stackTypes))); + break; + default: + throw new UnsupportedOperationException(format("Unsupported return place convention: %s", bestChoice.getReturnPlaceConvention())); + } break; case USE_NULL_FLAG: block.append(arguments.get(realParameterIndex)); @@ -293,12 +320,30 @@ else if (type == ConnectorSession.class) { block.append(invoke(binding, name)); if (function.isNullable()) { - block.append(unboxPrimitiveIfNecessary(scope, returnType)); + switch (bestChoice.getReturnPlaceConvention()) { + case STACK: + block.append(unboxPrimitiveIfNecessary(scope, returnType)); + break; + case PROVIDED_BLOCKBUILDER: + // no-op + break; + default: + throw new UnsupportedOperationException(format("Unsupported return place convention: %s", bestChoice.getReturnPlaceConvention())); + } } block.visitLabel(end); if (outputBlockVariableAndType.isPresent()) { - block.append(generateWrite(binder, scope, scope.getVariable("wasNull"), outputBlockVariableAndType.get().getType(), outputBlockVariableAndType.get().getOutputBlockVariable())); + switch (bestChoice.getReturnPlaceConvention()) { + case STACK: + block.append(generateWrite(binder, scope, scope.getVariable("wasNull"), outputBlockVariableAndType.get().getType(), outputBlockVariableAndType.get().getOutputBlockVariable())); + break; + case PROVIDED_BLOCKBUILDER: + // no-op + break; + default: + throw new UnsupportedOperationException(format("Unsupported return place convention: %s", bestChoice.getReturnPlaceConvention())); + } } return block; } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestProvidedBlockBuilderReturnPlaceConvention.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestProvidedBlockBuilderReturnPlaceConvention.java new file mode 100644 index 0000000000000..5e601ba961e0f --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestProvidedBlockBuilderReturnPlaceConvention.java @@ -0,0 +1,401 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionKind; +import com.facebook.presto.metadata.FunctionRegistry; +import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.SqlScalarFunction; +import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention; +import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; + +import static com.facebook.presto.metadata.Signature.typeVariable; +import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty; +import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL; +import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.USE_BOXED_TYPE; +import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention.PROVIDED_BLOCKBUILDER; +import static com.facebook.presto.operator.scalar.TestProvidedBlockBuilderReturnPlaceConvention.FunctionWithProvidedBlockReturnPlaceConvention1.PROVIDED_BLOCKBUILDER_CONVENTION1; +import static com.facebook.presto.operator.scalar.TestProvidedBlockBuilderReturnPlaceConvention.FunctionWithProvidedBlockReturnPlaceConvention2.PROVIDED_BLOCKBUILDER_CONVENTION2; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.spi.type.VarcharType.createVarcharType; +import static com.facebook.presto.util.Reflection.methodHandle; +import static com.google.common.primitives.Primitives.wrap; +import static org.testng.Assert.assertTrue; + +public class TestProvidedBlockBuilderReturnPlaceConvention + extends AbstractTestFunctions +{ + @BeforeClass + public void setUp() + { + registerScalarFunction(PROVIDED_BLOCKBUILDER_CONVENTION1); + registerScalarFunction(PROVIDED_BLOCKBUILDER_CONVENTION2); + } + + @Test + public void testProvidedBlockBuilderReturnNullOnNull() + { + assertFunction("identity1(123)", INTEGER, 123); + assertFunction("identity1(identity1(123))", INTEGER, 123); + assertFunction("identity1(CAST(null AS INTEGER))", INTEGER, null); + assertFunction("identity1(identity1(CAST(null AS INTEGER)))", INTEGER, null); + + assertFunction("identity1(123.4E0)", DOUBLE, 123.4); + assertFunction("identity1(identity1(123.4E0))", DOUBLE, 123.4); + assertFunction("identity1(CAST(null AS DOUBLE))", DOUBLE, null); + assertFunction("identity1(identity1(CAST(null AS DOUBLE)))", DOUBLE, null); + + assertFunction("identity1(true)", BOOLEAN, true); + assertFunction("identity1(identity1(true))", BOOLEAN, true); + assertFunction("identity1(CAST(null AS BOOLEAN))", BOOLEAN, null); + assertFunction("identity1(identity1(CAST(null AS BOOLEAN)))", BOOLEAN, null); + + assertFunction("identity1('abc')", createVarcharType(3), "abc"); + assertFunction("identity1(identity1('abc'))", createVarcharType(3), "abc"); + assertFunction("identity1(CAST(null AS VARCHAR))", VARCHAR, null); + assertFunction("identity1(identity1(CAST(null AS VARCHAR)))", VARCHAR, null); + + assertFunction("identity1(ARRAY[1,2,3])", new ArrayType(INTEGER), ImmutableList.of(1, 2, 3)); + assertFunction("identity1(identity1(ARRAY[1,2,3]))", new ArrayType(INTEGER), ImmutableList.of(1, 2, 3)); + assertFunction("identity1(CAST(null AS ARRAY))", new ArrayType(INTEGER), null); + assertFunction("identity1(identity1(CAST(null AS ARRAY)))", new ArrayType(INTEGER), null); + + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention1.hitProvidedBlockBuilderLong.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention1.hitProvidedBlockBuilderDouble.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention1.hitProvidedBlockBuilderBoolean.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention1.hitProvidedBlockBuilderSlice.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention1.hitProvidedBlockBuilderBlock.get() > 0); + } + + @Test + public void testProvidedBlockBuilderUseBoxedType() + { + assertFunction("identity2(123)", INTEGER, 123); + assertFunction("identity2(identity2(123))", INTEGER, 123); + assertFunction("identity2(CAST(null AS INTEGER))", INTEGER, null); + assertFunction("identity2(identity2(CAST(null AS INTEGER)))", INTEGER, null); + + assertFunction("identity2(123.4E0)", DOUBLE, 123.4); + assertFunction("identity2(identity2(123.4E0))", DOUBLE, 123.4); + assertFunction("identity2(CAST(null AS DOUBLE))", DOUBLE, null); + assertFunction("identity2(identity2(CAST(null AS DOUBLE)))", DOUBLE, null); + + assertFunction("identity2(true)", BOOLEAN, true); + assertFunction("identity2(identity2(true))", BOOLEAN, true); + assertFunction("identity2(CAST(null AS BOOLEAN))", BOOLEAN, null); + assertFunction("identity2(identity2(CAST(null AS BOOLEAN)))", BOOLEAN, null); + + assertFunction("identity2('abc')", createVarcharType(3), "abc"); + assertFunction("identity2(identity2('abc'))", createVarcharType(3), "abc"); + assertFunction("identity2(CAST(null AS VARCHAR))", VARCHAR, null); + assertFunction("identity2(identity2(CAST(null AS VARCHAR)))", VARCHAR, null); + + assertFunction("identity2(ARRAY[1,2,3])", new ArrayType(INTEGER), ImmutableList.of(1, 2, 3)); + assertFunction("identity2(identity2(ARRAY[1,2,3]))", new ArrayType(INTEGER), ImmutableList.of(1, 2, 3)); + assertFunction("identity2(CAST(null AS ARRAY))", new ArrayType(INTEGER), null); + assertFunction("identity2(identity2(CAST(null AS ARRAY)))", new ArrayType(INTEGER), null); + + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention2.hitProvidedBlockBuilderLong.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention2.hitProvidedBlockBuilderDouble.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention2.hitProvidedBlockBuilderBoolean.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention2.hitProvidedBlockBuilderSlice.get() > 0); + assertTrue(FunctionWithProvidedBlockReturnPlaceConvention2.hitProvidedBlockBuilderBlock.get() > 0); + } + + // null convention RETURN_NULL_ON_NULL + public static class FunctionWithProvidedBlockReturnPlaceConvention1 + extends SqlScalarFunction + { + private static final AtomicLong hitProvidedBlockBuilderLong = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderDouble = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderBoolean = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderSlice = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderBlock = new AtomicLong(); + + public static final FunctionWithProvidedBlockReturnPlaceConvention1 PROVIDED_BLOCKBUILDER_CONVENTION1 = new FunctionWithProvidedBlockReturnPlaceConvention1(); + + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_LONG = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention1.class, "providedBlockLong", Type.class, BlockBuilder.class, long.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_DOUBLE = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention1.class, "providedBlockDouble", Type.class, BlockBuilder.class, double.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_BOOLEAN = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention1.class, "providedBlockBoolean", Type.class, BlockBuilder.class, boolean.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_SLICE = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention1.class, "providedBlockSlice", Type.class, BlockBuilder.class, Slice.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_BLOCK = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention1.class, "providedBlockBlock", Type.class, BlockBuilder.class, Block.class); + + protected FunctionWithProvidedBlockReturnPlaceConvention1() + { + super(new Signature( + "identity1", + FunctionKind.SCALAR, + ImmutableList.of(typeVariable("T")), + ImmutableList.of(), + parseTypeSignature("T"), + ImmutableList.of(parseTypeSignature("T")), + false)); + } + + @Override + public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) + { + Type type = boundVariables.getTypeVariable("T"); + MethodHandle methodHandleStack = MethodHandles.identity(type.getJavaType()); + MethodHandle methodHandleProvidedBlock; + if (type.getJavaType() == long.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_LONG.bindTo(type); + } + else if (type.getJavaType() == double.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_DOUBLE.bindTo(type); + } + else if (type.getJavaType() == boolean.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_BOOLEAN.bindTo(type); + } + else if (type.getJavaType() == Slice.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_SLICE.bindTo(type); + } + else if (type.getJavaType() == Block.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_BLOCK.bindTo(type); + } + else { + throw new UnsupportedOperationException(); + } + + return new ScalarFunctionImplementation( + ImmutableList.of( + new ScalarImplementationChoice( + false, + ImmutableList.of(valueTypeArgumentProperty(RETURN_NULL_ON_NULL)), + ReturnPlaceConvention.STACK, + methodHandleStack, + Optional.empty()), + new ScalarImplementationChoice( + false, + ImmutableList.of(valueTypeArgumentProperty(RETURN_NULL_ON_NULL)), + PROVIDED_BLOCKBUILDER, + methodHandleProvidedBlock, + Optional.empty())), + isDeterministic()); + } + + public static void providedBlockLong(Type type, BlockBuilder output, long value) + { + hitProvidedBlockBuilderLong.incrementAndGet(); + type.writeLong(output, value); + } + + public static void providedBlockDouble(Type type, BlockBuilder output, double value) + { + hitProvidedBlockBuilderDouble.incrementAndGet(); + type.writeDouble(output, value); + } + + public static void providedBlockBoolean(Type type, BlockBuilder output, boolean value) + { + hitProvidedBlockBuilderBoolean.incrementAndGet(); + type.writeBoolean(output, value); + } + + public static void providedBlockSlice(Type type, BlockBuilder output, Slice value) + { + hitProvidedBlockBuilderSlice.incrementAndGet(); + type.writeSlice(output, value); + } + + public static void providedBlockBlock(Type type, BlockBuilder output, Block value) + { + hitProvidedBlockBuilderBlock.incrementAndGet(); + type.writeObject(output, value); + } + + @Override + public boolean isDeterministic() + { + return true; + } + + @Override + public boolean isHidden() + { + return false; + } + + @Override + public String getDescription() + { + return ""; + } + } + + // null convention USE_BOXED_TYPE + public static class FunctionWithProvidedBlockReturnPlaceConvention2 + extends SqlScalarFunction + { + private static final AtomicLong hitProvidedBlockBuilderLong = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderDouble = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderBoolean = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderSlice = new AtomicLong(); + private static final AtomicLong hitProvidedBlockBuilderBlock = new AtomicLong(); + + public static final FunctionWithProvidedBlockReturnPlaceConvention2 PROVIDED_BLOCKBUILDER_CONVENTION2 = new FunctionWithProvidedBlockReturnPlaceConvention2(); + + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_LONG = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention2.class, "providedBlockLong", Type.class, BlockBuilder.class, Long.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_DOUBLE = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention2.class, "providedBlockDouble", Type.class, BlockBuilder.class, Double.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_BOOLEAN = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention2.class, "providedBlockBoolean", Type.class, BlockBuilder.class, Boolean.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_SLICE = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention2.class, "providedBlockSlice", Type.class, BlockBuilder.class, Slice.class); + private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK_BLOCK = methodHandle(FunctionWithProvidedBlockReturnPlaceConvention2.class, "providedBlockBlock", Type.class, BlockBuilder.class, Block.class); + + protected FunctionWithProvidedBlockReturnPlaceConvention2() + { + super(new Signature( + "identity2", + FunctionKind.SCALAR, + ImmutableList.of(typeVariable("T")), + ImmutableList.of(), + parseTypeSignature("T"), + ImmutableList.of(parseTypeSignature("T")), + false)); + } + + @Override + public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) + { + Type type = boundVariables.getTypeVariable("T"); + MethodHandle methodHandleStack = MethodHandles.identity(wrap(type.getJavaType())); + MethodHandle methodHandleProvidedBlock; + if (type.getJavaType() == long.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_LONG.bindTo(type); + } + else if (type.getJavaType() == double.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_DOUBLE.bindTo(type); + } + else if (type.getJavaType() == boolean.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_BOOLEAN.bindTo(type); + } + else if (type.getJavaType() == Slice.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_SLICE.bindTo(type); + } + else if (type.getJavaType() == Block.class) { + methodHandleProvidedBlock = METHOD_HANDLE_PROVIDED_BLOCK_BLOCK.bindTo(type); + } + else { + throw new UnsupportedOperationException(); + } + + return new ScalarFunctionImplementation( + ImmutableList.of( + new ScalarImplementationChoice( + true, + ImmutableList.of(valueTypeArgumentProperty(USE_BOXED_TYPE)), + ReturnPlaceConvention.STACK, + methodHandleStack, + Optional.empty()), + new ScalarImplementationChoice( + true, + ImmutableList.of(valueTypeArgumentProperty(USE_BOXED_TYPE)), + PROVIDED_BLOCKBUILDER, + methodHandleProvidedBlock, + Optional.empty())), + isDeterministic()); + } + + public static void providedBlockLong(Type type, BlockBuilder output, Long value) + { + hitProvidedBlockBuilderLong.incrementAndGet(); + if (value == null) { + output.appendNull(); + } + else { + type.writeLong(output, value); + } + } + + public static void providedBlockDouble(Type type, BlockBuilder output, Double value) + { + hitProvidedBlockBuilderDouble.incrementAndGet(); + if (value == null) { + output.appendNull(); + } + else { + type.writeDouble(output, value); + } + } + + public static void providedBlockBoolean(Type type, BlockBuilder output, Boolean value) + { + hitProvidedBlockBuilderBoolean.incrementAndGet(); + if (value == null) { + output.appendNull(); + } + else { + type.writeBoolean(output, value); + } + } + + public static void providedBlockSlice(Type type, BlockBuilder output, Slice value) + { + hitProvidedBlockBuilderSlice.incrementAndGet(); + if (value == null) { + output.appendNull(); + } + else { + type.writeSlice(output, value); + } + } + + public static void providedBlockBlock(Type type, BlockBuilder output, Block value) + { + hitProvidedBlockBuilderBlock.incrementAndGet(); + if (value == null) { + output.appendNull(); + } + else { + type.writeObject(output, value); + } + } + + @Override + public boolean isDeterministic() + { + return true; + } + + @Override + public boolean isHidden() + { + return false; + } + + @Override + public String getDescription() + { + return ""; + } + } +}