diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java index 8081b96249aac..081f448a4c60c 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java @@ -214,6 +214,8 @@ import static com.facebook.presto.operator.scalar.ArrayConcatFunction.ARRAY_CONCAT_FUNCTION; import static com.facebook.presto.operator.scalar.ArrayConstructor.ARRAY_CONSTRUCTOR; import static com.facebook.presto.operator.scalar.ArrayFlattenFunction.ARRAY_FLATTEN_FUNCTION; +import static com.facebook.presto.operator.scalar.ArrayIdentityDirectFunction.ARRAY_IDENTITY_DIRECT_FUNCTION; +import static com.facebook.presto.operator.scalar.ArrayIdentityFunction.ARRAY_IDENTITY_FUNCTION; import static com.facebook.presto.operator.scalar.ArrayJoin.ARRAY_JOIN; import static com.facebook.presto.operator.scalar.ArrayJoin.ARRAY_JOIN_WITH_NULL_REPLACEMENT; import static com.facebook.presto.operator.scalar.ArrayReduceFunction.ARRAY_REDUCE_FUNCTION; @@ -546,6 +548,7 @@ public WindowFunctionSupplier load(SpecializedFunctionKey key) .function(MAP_ELEMENT_AT) .function(MAP_CONCAT_FUNCTION) .function(ARRAY_FLATTEN_FUNCTION) + .functions(ARRAY_IDENTITY_DIRECT_FUNCTION, ARRAY_IDENTITY_FUNCTION) .function(ARRAY_CONCAT_FUNCTION) .functions(ARRAY_CONSTRUCTOR, ARRAY_SUBSCRIPT, ARRAY_TO_JSON, JSON_TO_ARRAY, JSON_STRING_TO_ARRAY) .functions(new MapSubscriptOperator(featuresConfig.isLegacyMapSubscript())) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConcatFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConcatFunction.java index 8746f39882049..a084c14dd8c5a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConcatFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConcatFunction.java @@ -101,6 +101,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in nCopies(arity, Optional.empty()), methodHandleAndConstructor.getMethodHandle(), Optional.of(methodHandleAndConstructor.getConstructor()), + false, isDeterministic()); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIdentityDirectFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIdentityDirectFunction.java new file mode 100644 index 0000000000000..bfd6dd3f95840 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIdentityDirectFunction.java @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionKind; +import com.facebook.presto.metadata.FunctionRegistry; +import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.SqlScalarFunction; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSignatureParameter; +import com.google.common.collect.ImmutableList; + +import java.lang.invoke.MethodHandle; +import java.util.Optional; + +import static com.facebook.presto.metadata.Signature.typeVariable; +import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.util.Reflection.methodHandle; + +public class ArrayIdentityDirectFunction + extends SqlScalarFunction +{ + public static final ArrayIdentityDirectFunction ARRAY_IDENTITY_DIRECT_FUNCTION = new ArrayIdentityDirectFunction(); + private static final String FUNCTION_NAME = "array_identity_direct"; + private static final MethodHandle METHOD_HANDLE = methodHandle(ArrayIdentityDirectFunction.class, "identity", Type.class, Type.class, BlockBuilder.class, Block.class); + + private ArrayIdentityDirectFunction() + { + super(new Signature(FUNCTION_NAME, + FunctionKind.SCALAR, + ImmutableList.of(typeVariable("E")), + ImmutableList.of(), + parseTypeSignature("array(E)"), + ImmutableList.of(parseTypeSignature("array(E)")), + false)); + } + + @Override + public boolean isHidden() + { + return false; + } + + @Override + public boolean isDeterministic() + { + return true; + } + + @Override + public String getDescription() + { + return "array identity"; + } + + @Override + public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) + { + Type elementType = boundVariables.getTypeVariable("E"); + Type arrayType = typeManager.getParameterizedType(StandardTypes.ARRAY, ImmutableList.of(TypeSignatureParameter.of(elementType.getTypeSignature()))); + MethodHandle methodHandle = METHOD_HANDLE.bindTo(elementType).bindTo(arrayType); + return new ScalarFunctionImplementation( + false, + ImmutableList.of(false), + ImmutableList.of(false), + ImmutableList.of(Optional.empty()), + methodHandle, + Optional.empty(), + true, + isDeterministic()); + } + + public static void identity(Type type, Type arrayType, BlockBuilder outputBlock, Block array) + { + BlockBuilder entryBlockBuilder = outputBlock.beginBlockEntry(); + for (int i = 0; i < array.getPositionCount(); i++) { + if (array.isNull(i)) { + entryBlockBuilder.appendNull(); + } + else { + type.appendTo(array, i, entryBlockBuilder); + } + } + outputBlock.closeEntry(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIdentityFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIdentityFunction.java new file mode 100644 index 0000000000000..641153641702c --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIdentityFunction.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.annotation.UsedByGeneratedCode; +import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionKind; +import com.facebook.presto.metadata.FunctionRegistry; +import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.SqlScalarFunction; +import com.facebook.presto.spi.PageBuilder; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSignatureParameter; +import com.google.common.collect.ImmutableList; + +import java.lang.invoke.MethodHandle; +import java.util.Optional; + +import static com.facebook.presto.metadata.Signature.typeVariable; +import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.util.Reflection.methodHandle; + +public class ArrayIdentityFunction + extends SqlScalarFunction +{ + public static final ArrayIdentityFunction ARRAY_IDENTITY_FUNCTION = new ArrayIdentityFunction(); + private static final String FUNCTION_NAME = "array_identity"; + + private static final MethodHandle METHOD_HANDLE = methodHandle(ArrayIdentityFunction.class, "identity", Type.class, Type.class, Object.class, Block.class); + private static final MethodHandle STATE_FACTORY = methodHandle(ArrayIdentityFunction.class, "createState", ArrayType.class); + + private ArrayIdentityFunction() + { + super(new Signature(FUNCTION_NAME, + FunctionKind.SCALAR, + ImmutableList.of(typeVariable("E")), + ImmutableList.of(), + parseTypeSignature("array(E)"), + ImmutableList.of(parseTypeSignature("array(E)")), + false)); + } + + @Override + public boolean isHidden() + { + return false; + } + + @Override + public boolean isDeterministic() + { + return true; + } + + @Override + public String getDescription() + { + return "array identity"; + } + + @Override + public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) + { + Type elementType = boundVariables.getTypeVariable("E"); + Type arrayType = typeManager.getParameterizedType(StandardTypes.ARRAY, ImmutableList.of(TypeSignatureParameter.of(elementType.getTypeSignature()))); + MethodHandle methodHandle = METHOD_HANDLE.bindTo(elementType).bindTo(arrayType); + return new ScalarFunctionImplementation( + false, + ImmutableList.of(false), + ImmutableList.of(false), + ImmutableList.of(Optional.empty()), + methodHandle, + Optional.of(STATE_FACTORY.bindTo(arrayType)), + false, + isDeterministic()); + } + + @UsedByGeneratedCode + public static Object createState(ArrayType arrayType) + { + return new PageBuilder(ImmutableList.of(arrayType)); + } + + public static Block identity(Type type, Type arrayType, Object state, Block array) + { + PageBuilder pageBuilder = (PageBuilder) state; + if (pageBuilder.isFull()) { + pageBuilder.reset(); + } + BlockBuilder arrayBlockBuilder = pageBuilder.getBlockBuilder(0); + BlockBuilder blockBuilder = arrayBlockBuilder.beginBlockEntry(); + + for (int i = 0; i < array.getPositionCount(); i++) { + if (array.isNull(i)) { + blockBuilder.appendNull(); + } + else { + type.appendTo(array, i, blockBuilder); + } + } + + arrayBlockBuilder.closeEntry(); + pageBuilder.declarePosition(); + return (Block) arrayType.getObject(arrayBlockBuilder, arrayBlockBuilder.getPositionCount() - 1); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java index 3af984b239ad3..716db500ca453 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java @@ -111,6 +111,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in ImmutableList.of(Optional.empty(), Optional.of(UnaryFunctionInterface.class)), methodHandle(generatedClass, "transform", PageBuilder.class, Block.class, UnaryFunctionInterface.class), Optional.of(methodHandle(generatedClass, "createPageBuilder")), + false, isDeterministic()); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java index dca56aa4c4241..11c7420f1c66e 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java @@ -110,6 +110,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in nCopies(arity, Optional.empty()), methodHandleAndConstructor.getMethodHandle(), Optional.of(methodHandleAndConstructor.getConstructor()), + false, isDeterministic()); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java index ac12e9690e881..e043db114ed4f 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java @@ -98,6 +98,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in ImmutableList.of(Optional.empty(), Optional.empty()), METHOD_HANDLE.bindTo(mapType).bindTo(keyEqual).bindTo(keyHashCode), Optional.of(instanceFactory), + false, isDeterministic()); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapFilterFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapFilterFunction.java index 7a023f2cf8310..a049822d5e821 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapFilterFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapFilterFunction.java @@ -119,6 +119,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in ImmutableList.of(Optional.empty(), Optional.of(BinaryFunctionInterface.class)), generateFilter(mapType), Optional.of(STATE_FACTORY.bindTo(mapType)), + false, isDeterministic()); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformKeyFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformKeyFunction.java index 9e36d1e075a68..fd1e0302dc1c9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformKeyFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformKeyFunction.java @@ -130,6 +130,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in ImmutableList.of(Optional.empty(), Optional.of(BinaryFunctionInterface.class)), generateTransformKey(keyType, transformedKeyType, valueType, resultMapType), Optional.of(STATE_FACTORY.bindTo(resultMapType)), + false, isDeterministic()); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java index 50e7c82d1d7d5..7263b231a96ab 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java @@ -125,6 +125,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in ImmutableList.of(Optional.empty(), Optional.of(BinaryFunctionInterface.class)), generateTransform(keyType, valueType, transformedValueType, resultMapType), Optional.of(STATE_FACTORY.bindTo(resultMapType)), + false, isDeterministic()); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java index b20e76a2d7bc7..e835abd144f5d 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java @@ -89,6 +89,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in implementation.getLambdaInterface(), methodHandleAndConstructor.get().getMethodHandle(), methodHandleAndConstructor.get().getConstructor(), + false, isDeterministic()); } @@ -104,6 +105,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in implementation.getLambdaInterface(), methodHandle.get().getMethodHandle(), methodHandle.get().getConstructor(), + false, isDeterministic()); } } @@ -122,6 +124,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in implementation.getLambdaInterface(), methodHandle.get().getMethodHandle(), methodHandle.get().getConstructor(), + false, isDeterministic()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarFunctionImplementation.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarFunctionImplementation.java index f938fa0b1b14d..c50ddd1c293ac 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarFunctionImplementation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarFunctionImplementation.java @@ -33,6 +33,7 @@ public final class ScalarFunctionImplementation private final List> lambdaInterface; private final MethodHandle methodHandle; private final Optional instanceFactory; + private final boolean writeToBlockBuilderParamater; private final boolean deterministic; public ScalarFunctionImplementation(boolean nullable, List nullableArguments, MethodHandle methodHandle, boolean deterministic) @@ -44,6 +45,7 @@ public ScalarFunctionImplementation(boolean nullable, List nullableArgu nCopies(nullableArguments.size(), Optional.empty()), methodHandle, Optional.empty(), + false, deterministic); } @@ -56,6 +58,7 @@ public ScalarFunctionImplementation(boolean nullable, List nullableArgu nCopies(nullableArguments.size(), Optional.empty()), methodHandle, Optional.empty(), + false, deterministic); } @@ -74,6 +77,7 @@ public ScalarFunctionImplementation( lambdaInterface, methodHandle, Optional.empty(), + false, deterministic); } @@ -84,6 +88,7 @@ public ScalarFunctionImplementation( List> lambdaInterface, MethodHandle methodHandle, Optional instanceFactory, + boolean writeToBlockBuilderParamater, boolean deterministic) { this.nullable = nullable; @@ -92,6 +97,7 @@ public ScalarFunctionImplementation( this.lambdaInterface = ImmutableList.copyOf(requireNonNull(lambdaInterface, "lambdaInterface is null")); this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); this.instanceFactory = requireNonNull(instanceFactory, "instanceFactory is null"); + this.writeToBlockBuilderParamater = writeToBlockBuilderParamater; this.deterministic = deterministic; if (instanceFactory.isPresent()) { @@ -146,6 +152,11 @@ public Optional getInstanceFactory() return instanceFactory; } + public boolean isWriteToBlockBuilderParamater() + { + return writeToBlockBuilderParamater; + } + public boolean isDeterministic() { return deterministic; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/FunctionInvoker.java b/presto-main/src/main/java/com/facebook/presto/sql/FunctionInvoker.java index df8cfccd842f2..dde910863ff57 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/FunctionInvoker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/FunctionInvoker.java @@ -17,6 +17,10 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.Type; import com.google.common.base.Defaults; import com.google.common.base.Throwables; @@ -26,6 +30,7 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static java.lang.invoke.MethodHandleProxies.asInterfaceInstance; import static java.util.Objects.requireNonNull; @@ -80,7 +85,16 @@ public Object invoke(Signature function, ConnectorSession session, List } try { - return method.invokeWithArguments(actualArguments); + if (implementation.isWriteToBlockBuilderParamater()) { + Type arrayType = new ArrayType(INTEGER); + BlockBuilder builder = arrayType.createBlockBuilder(new BlockBuilderStatus(), 20); + actualArguments.add(0, builder); + method.invokeWithArguments(actualArguments); + return arrayType.getObject(builder, builder.getPositionCount() - 1); + } + else { + return method.invokeWithArguments(actualArguments); + } } catch (Throwable throwable) { throw Throwables.propagate(throwable); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/AndCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/AndCodeGenerator.java index 0a19c9c611365..1a80729210fb1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/AndCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/AndCodeGenerator.java @@ -40,8 +40,8 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon .comment("AND") .setDescription("AND"); - BytecodeNode left = generator.generate(arguments.get(0)); - BytecodeNode right = generator.generate(arguments.get(1)); + BytecodeNode left = generator.generate(arguments.get(0), generator.getOutputBlockBuilder()); + BytecodeNode right = generator.generate(arguments.get(1), generator.getOutputBlockBuilder()); block.append(left); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java index 4abd084e2f3dc..d74c987e6a198 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java @@ -35,6 +35,7 @@ public class BytecodeGeneratorContext private final CachedInstanceBinder cachedInstanceBinder; private final FunctionRegistry registry; private final PreGeneratedExpressions preGeneratedExpressions; + private final Optional outputBlockBuilder; private final Variable wasNull; public BytecodeGeneratorContext( @@ -43,6 +44,7 @@ public BytecodeGeneratorContext( CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, FunctionRegistry registry, + Optional outputBlockBuilder, PreGeneratedExpressions preGeneratedExpressions) { requireNonNull(rowExpressionCompiler, "bytecodeGenerator is null"); @@ -57,6 +59,7 @@ public BytecodeGeneratorContext( this.cachedInstanceBinder = cachedInstanceBinder; this.registry = registry; this.preGeneratedExpressions = preGeneratedExpressions; + this.outputBlockBuilder = outputBlockBuilder; this.wasNull = scope.getVariable("wasNull"); } @@ -70,14 +73,14 @@ public CallSiteBinder getCallSiteBinder() return callSiteBinder; } - public BytecodeNode generate(RowExpression expression) + public BytecodeNode generate(RowExpression expression, Optional outputBlockBuilder) { - return generate(expression, Optional.empty()); + return generate(expression, outputBlockBuilder, Optional.empty()); } - public BytecodeNode generate(RowExpression expression, Optional lambdaInterface) + public BytecodeNode generate(RowExpression expression, Optional outputBlockBuilder, Optional lambdaInterface) { - return rowExpressionCompiler.compile(expression, scope, lambdaInterface); + return rowExpressionCompiler.compile(expression, scope, outputBlockBuilder, lambdaInterface); } public FunctionRegistry getRegistry() @@ -85,6 +88,11 @@ public FunctionRegistry getRegistry() return registry; } + public Optional getOutputBlockBuilder() + { + return outputBlockBuilder; + } + /** * Generates a function call with null handling, automatic binding of session parameter, etc. */ @@ -96,7 +104,7 @@ public BytecodeNode generateCall(String name, ScalarFunctionImplementation funct FieldDefinition field = cachedInstanceBinder.getCachedInstance(function.getInstanceFactory().get()); instance = Optional.of(scope.getThis().getField(field)); } - return generateInvocation(scope, name, function, instance, arguments, binding); + return generateInvocation(scope, name, function, instance, outputBlockBuilder, arguments, binding); } public Variable wasNull() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java index d3f56ea994cde..4381ec112c737 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java @@ -20,16 +20,20 @@ import com.facebook.presto.bytecode.control.IfStatement; import com.facebook.presto.bytecode.expression.BytecodeExpression; import com.facebook.presto.bytecode.instruction.LabelNode; +import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.relational.CallExpression; +import com.facebook.presto.sql.relational.RowExpression; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.primitives.Primitives; import io.airlift.slice.Slice; +import sun.reflect.generics.reflectiveObjects.NotImplementedException; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; @@ -42,6 +46,17 @@ import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantTrue; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeDynamic; import static com.facebook.presto.sql.gen.Bootstrap.BOOTSTRAP_METHOD; +import static com.facebook.presto.sql.relational.Signatures.BIND; +import static com.facebook.presto.sql.relational.Signatures.CAST; +import static com.facebook.presto.sql.relational.Signatures.COALESCE; +import static com.facebook.presto.sql.relational.Signatures.DEREFERENCE; +import static com.facebook.presto.sql.relational.Signatures.IF; +import static com.facebook.presto.sql.relational.Signatures.IN; +import static com.facebook.presto.sql.relational.Signatures.IS_NULL; +import static com.facebook.presto.sql.relational.Signatures.NULL_IF; +import static com.facebook.presto.sql.relational.Signatures.ROW_CONSTRUCTOR; +import static com.facebook.presto.sql.relational.Signatures.SWITCH; +import static com.facebook.presto.sql.relational.Signatures.TRY; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static java.lang.String.format; @@ -157,7 +172,40 @@ public static BytecodeExpression loadConstant(Binding binding) binding.getType().returnType()); } - public static BytecodeNode generateInvocation(Scope scope, String name, ScalarFunctionImplementation function, Optional instance, List arguments, Binding binding) + public static boolean directWrittenToBlock(RowExpression expression, FunctionRegistry registry) + { + if (!(expression instanceof CallExpression)) { + return false; + } + + CallExpression call = (CallExpression) expression; + if (call.getSignature().getName().equals(CAST)) { + return false; + } + else { + switch (call.getSignature().getName()) { + // lazy evaluation + case IF: + case NULL_IF: + case SWITCH: + case TRY: + case IS_NULL: + case COALESCE: + case IN: + case "AND": + case "OR": + case DEREFERENCE: + case ROW_CONSTRUCTOR: + case BIND: + return false; + default: + ScalarFunctionImplementation function = registry.getScalarFunctionImplementation(call.getSignature()); + return function.isWriteToBlockBuilderParamater(); + } + } + } + + public static BytecodeNode generateInvocation(Scope scope, String name, ScalarFunctionImplementation function, Optional instance, Optional outputBlockBuilder, List arguments, Binding binding) { MethodType methodType = binding.getType(); @@ -175,6 +223,13 @@ public static BytecodeNode generateInvocation(Scope scope, String name, ScalarFu // Index of current parameter in the MethodHandle int currentParameterIndex = 0; + if (function.isWriteToBlockBuilderParamater()) { + checkState(outputBlockBuilder.isPresent()); + block.append(outputBlockBuilder.get()); + Class type = methodType.parameterArray()[currentParameterIndex]; + stackTypes.add(type); + currentParameterIndex++; + } // Index of parameter (without @IsNull) in Presto function int realParameterIndex = 0; @@ -215,11 +270,17 @@ else if (type == ConnectorSession.class) { } currentParameterIndex++; } + block.append(invoke(binding, name)); - if (function.isNullable()) { + if (function.isNullable() && !function.isWriteToBlockBuilderParamater()) { block.append(unboxPrimitiveIfNecessary(scope, returnType)); } + if (!function.isWriteToBlockBuilderParamater() && outputBlockBuilder.isPresent()) { + // Result is on the stack, append it to the BlockBuilder + throw new NotImplementedException(); + } + block.visitLabel(end); return block; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java index cdf94b9b02403..616967d8a7bc4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java @@ -33,6 +33,6 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon .getRegistry() .getCoercion(argument.getType(), returnType); - return generatorContext.generateCall(function.getName(), generatorContext.getRegistry().getScalarFunctionImplementation(function), ImmutableList.of(generatorContext.generate(argument))); + return generatorContext.generateCall(function.getName(), generatorContext.getRegistry().getScalarFunctionImplementation(function), ImmutableList.of(generatorContext.generate(argument, generatorContext.getOutputBlockBuilder()))); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/CoalesceCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/CoalesceCodeGenerator.java index ae30f368324ed..73352a26acc0b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/CoalesceCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/CoalesceCodeGenerator.java @@ -36,7 +36,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon { List operands = new ArrayList<>(); for (RowExpression expression : arguments) { - operands.add(generatorContext.generate(expression)); + operands.add(generatorContext.generate(expression, generatorContext.getOutputBlockBuilder())); } Variable wasNull = generatorContext.wasNull(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java index 2bd46ac9ec908..ed90092d66eac 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java @@ -47,6 +47,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.bytecode.Access.PUBLIC; @@ -54,6 +55,7 @@ import static com.facebook.presto.bytecode.OpCode.NOP; import static com.facebook.presto.bytecode.Parameter.arg; import static com.facebook.presto.bytecode.ParameterizedType.type; +import static com.facebook.presto.sql.gen.BytecodeUtils.directWrittenToBlock; import static com.facebook.presto.sql.gen.BytecodeUtils.generateWrite; import static com.facebook.presto.sql.gen.LambdaAndTryExpressionExtractor.extractLambdaAndTryExpressions; import static com.facebook.presto.sql.gen.TryCodeGenerator.defineTryMethod; @@ -289,7 +291,7 @@ private void generateFilterMethod( .comment("boolean wasNull = false;") .putVariable(wasNullVariable, false) .comment("evaluate filter: " + filter) - .append(compiler.compile(filter, scope)) + .append(compiler.compile(filter, scope, Optional.empty())) .comment("if (wasNull) return false;") .getVariable(wasNullVariable) .ifFalseGoto(end) @@ -324,14 +326,25 @@ private void generateProjectMethod( metadata.getFunctionRegistry(), preGeneratedExpressions); - method.getBody() - .comment("boolean wasNull = false;") - .putVariable(wasNullVariable, false) - .getVariable(output) - .comment("evaluate projection: " + projection.toString()) - .append(compiler.compile(projection, scope)) - .append(generateWrite(callSiteBinder, scope, wasNullVariable, projection.getType())) - .ret(); + if (!directWrittenToBlock(projection, metadata.getFunctionRegistry())) { + method.getBody() + .comment("boolean wasNull = false;") + .putVariable(wasNullVariable, false) + .getVariable(output) + .comment("evaluate projection: " + projection.toString()) + .append(compiler.compile(projection, scope, Optional.empty())) + .append(generateWrite(callSiteBinder, scope, wasNullVariable, projection.getType())) + .ret(); + } + else { + method.getBody() + .comment("boolean wasNull = false;") + .putVariable(wasNullVariable, false) + .getVariable(output) + .comment("evaluate projection: " + projection.toString()) + .append(compiler.compile(projection, scope, Optional.of(output))) + .ret(); + } } private static RowExpressionVisitor fieldReferenceCompiler(Variable cursorVariable) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/DereferenceCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/DereferenceCodeGenerator.java index 04077d55ac871..fc98f0ba9db47 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/DereferenceCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/DereferenceCodeGenerator.java @@ -47,7 +47,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon // clear the wasNull flag before evaluating the row value block.putVariable(wasNull, false); - block.append(generator.generate(arguments.get(0))).putVariable(rowBlock); + block.append(generator.generate(arguments.get(0), generator.getOutputBlockBuilder())).putVariable(rowBlock); IfStatement ifRowBlockIsNull = new IfStatement("if row block is null...") .condition(wasNull); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java index 01e83080f5708..600d62e5fd5fd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; public class FunctionCallCodeGenerator implements BytecodeGenerator @@ -36,7 +37,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon List argumentsBytecode = new ArrayList<>(); for (int i = 0; i < arguments.size(); i++) { RowExpression argument = arguments.get(i); - argumentsBytecode.add(context.generate(argument, function.getLambdaInterface().get(i))); + argumentsBytecode.add(context.generate(argument, Optional.empty(), function.getLambdaInterface().get(i))); } return context.generateCall(signature.getName(), function, argumentsBytecode); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/IfCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/IfCodeGenerator.java index b1df3df8876bd..c895de90133e7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/IfCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/IfCodeGenerator.java @@ -36,7 +36,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon Variable wasNull = context.wasNull(); BytecodeBlock condition = new BytecodeBlock() - .append(context.generate(arguments.get(0))) + .append(context.generate(arguments.get(0), context.getOutputBlockBuilder())) .comment("... and condition value was not null") .append(wasNull) .invokeStatic(CompilerOperations.class, "not", boolean.class, boolean.class) @@ -45,7 +45,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon return new IfStatement() .condition(condition) - .ifTrue(context.generate(arguments.get(1))) - .ifFalse(context.generate(arguments.get(2))); + .ifTrue(context.generate(arguments.get(1), context.getOutputBlockBuilder())) + .ifFalse(context.generate(arguments.get(2), context.getOutputBlockBuilder())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java index 90bb8b6571442..7ff381b6bc8c8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java @@ -109,13 +109,13 @@ static SwitchGenerationCase checkSwitchGenerationCase(Type type, List arguments) { - BytecodeNode value = generatorContext.generate(arguments.get(0)); + BytecodeNode value = generatorContext.generate(arguments.get(0), generatorContext.getOutputBlockBuilder()); List values = arguments.subList(1, arguments.size()); ImmutableList.Builder valuesBytecode = ImmutableList.builder(); for (int i = 1; i < arguments.size(); i++) { - BytecodeNode testNode = generatorContext.generate(arguments.get(i)); + BytecodeNode testNode = generatorContext.generate(arguments.get(i), generatorContext.getOutputBlockBuilder()); valuesBytecode.add(testNode); } @@ -132,7 +132,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon ImmutableSet.Builder constantValuesBuilder = ImmutableSet.builder(); for (RowExpression testValue : values) { - BytecodeNode testBytecode = generatorContext.generate(testValue); + BytecodeNode testBytecode = generatorContext.generate(testValue, generatorContext.getOutputBlockBuilder()); if (testValue instanceof ConstantExpression && ((ConstantExpression) testValue).getValue() != null) { ConstantExpression constant = (ConstantExpression) testValue; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/InvokeFunctionBytecodeExpression.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/InvokeFunctionBytecodeExpression.java index e6d3820f81154..92d41c0e71c41 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/InvokeFunctionBytecodeExpression.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/InvokeFunctionBytecodeExpression.java @@ -65,7 +65,7 @@ private InvokeFunctionBytecodeExpression( { super(type(function.getMethodHandle().type().returnType())); - this.invocation = generateInvocation(scope, name, function, instance, parameters.stream().map(BytecodeNode.class::cast).collect(toImmutableList()), binding); + this.invocation = generateInvocation(scope, name, function, instance, Optional.empty(), parameters.stream().map(BytecodeNode.class::cast).collect(toImmutableList()), binding); this.oneLineDescription = name + "(" + Joiner.on(", ").join(parameters) + ")"; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/IsNullCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/IsNullCodeGenerator.java index 38b904b6851e8..93c5aeaacf96b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/IsNullCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/IsNullCodeGenerator.java @@ -40,7 +40,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon return loadBoolean(true); } - BytecodeNode value = generatorContext.generate(argument); + BytecodeNode value = generatorContext.generate(argument, generatorContext.getOutputBlockBuilder()); // evaluate the expression, pop the produced value, and load the null flag Variable wasNull = generatorContext.wasNull(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java index ecff33aba4465..1a5d24372b0ec 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java @@ -217,7 +217,7 @@ private void generateFilterMethod( metadata.getFunctionRegistry(), preGeneratedExpressions); - BytecodeNode visitorBody = compiler.compile(filter, scope); + BytecodeNode visitorBody = compiler.compile(filter, scope, Optional.empty()); Variable result = scope.declareVariable(boolean.class, "result"); body.append(visitorBody) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java index 4893c753de86c..c87ebdf45ca95 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java @@ -47,6 +47,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Optional; import static com.facebook.presto.bytecode.Access.FINAL; import static com.facebook.presto.bytecode.Access.PRIVATE; @@ -132,7 +133,7 @@ private static CompiledLambda defineLambdaMethodAndField( Scope scope = method.getScope(); Variable wasNull = scope.declareVariable(boolean.class, "wasNull"); - BytecodeNode compiledBody = innerExpressionCompiler.compile(lambda.getBody(), scope); + BytecodeNode compiledBody = innerExpressionCompiler.compile(lambda.getBody(), scope, Optional.empty()); method.getBody() .putVariable(wasNull, false) .append(compiledBody) @@ -193,7 +194,7 @@ public static BytecodeNode generateLambda( for (RowExpression captureExpression : captureExpressions) { Class valueType = Primitives.wrap(captureExpression.getType().getJavaType()); Variable valueVariable = scope.createTempVariable(valueType); - block.append(context.generate(captureExpression)); + block.append(context.generate(captureExpression, Optional.empty())); block.append(boxPrimitiveIfNecessary(scope, valueType)); block.putVariable(valueVariable); block.append(wasNull.set(constantFalse())); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java index dfe495463fb3a..0931dc7f8d14c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java @@ -46,7 +46,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon // push first arg on the stack BytecodeBlock block = new BytecodeBlock() .comment("check if first arg is null") - .append(generatorContext.generate(first)) + .append(generatorContext.generate(first, generatorContext.getOutputBlockBuilder())) .append(BytecodeUtils.ifWasNullPopAndGoto(scope, notMatch, void.class)); Type firstType = first.getType(); @@ -60,7 +60,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon equalsFunction, ImmutableList.of( cast(generatorContext, new BytecodeBlock().dup(firstType.getJavaType()), firstType, equalsSignature.getArgumentTypes().get(0)), - cast(generatorContext, generatorContext.generate(second), secondType, equalsSignature.getArgumentTypes().get(1)))); + cast(generatorContext, generatorContext.generate(second, generatorContext.getOutputBlockBuilder()), secondType, equalsSignature.getArgumentTypes().get(1)))); BytecodeBlock conditionBlock = new BytecodeBlock() .append(equalsCall) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/OrCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/OrCodeGenerator.java index 0687ecc5c0334..0d84e871004dd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/OrCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/OrCodeGenerator.java @@ -40,8 +40,8 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon .comment("OR") .setDescription("OR"); - BytecodeNode left = generator.generate(arguments.get(0)); - BytecodeNode right = generator.generate(arguments.get(1)); + BytecodeNode left = generator.generate(arguments.get(0), generator.getOutputBlockBuilder()); + BytecodeNode right = generator.generate(arguments.get(1), generator.getOutputBlockBuilder()); block.append(left); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java index d2a0d0f36fd53..9afe677307472 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java @@ -92,6 +92,7 @@ import static com.facebook.presto.bytecode.expression.BytecodeExpressions.not; import static com.facebook.presto.operator.project.PageFieldsToInputParametersRewriter.rewritePageFieldsToInputParameters; import static com.facebook.presto.spi.StandardErrorCode.COMPILER_ERROR; +import static com.facebook.presto.sql.gen.BytecodeUtils.directWrittenToBlock; import static com.facebook.presto.sql.gen.BytecodeUtils.generateWrite; import static com.facebook.presto.sql.gen.BytecodeUtils.invoke; import static com.facebook.presto.sql.gen.LambdaAndTryExpressionExtractor.extractLambdaAndTryExpressions; @@ -369,10 +370,18 @@ private MethodDefinition generateProjectMethod( metadata.getFunctionRegistry(), preGeneratedExpressions); - body.append(thisVariable.getField(blockBuilder)) - .append(compiler.compile(projection, scope)) - .append(generateWrite(callSiteBinder, scope, wasNullVariable, projection.getType())) - .ret(); + if (!directWrittenToBlock(projection, metadata.getFunctionRegistry())) { + body.append(thisVariable.getField(blockBuilder)) + .append(compiler.compile(projection, scope, Optional.empty())) + .append(generateWrite(callSiteBinder, scope, wasNullVariable, projection.getType())) + .ret(); + } + else { + Variable blockBuilderVariable = scope.createTempVariable(BlockBuilder.class); + body.append(blockBuilderVariable.set(thisVariable.getField(blockBuilder))) + .append(compiler.compile(projection, scope, Optional.of(blockBuilderVariable))) + .ret(); + } return method; } @@ -546,7 +555,7 @@ private MethodDefinition generateFilterMethod( preGeneratedExpressions); Variable result = scope.declareVariable(boolean.class, "result"); - body.append(compiler.compile(filter, scope)) + body.append(compiler.compile(filter, scope, Optional.empty())) // store result so we can check for null .putVariable(result) .append(and(not(wasNullVariable), result).ret()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/RowConstructorCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowConstructorCodeGenerator.java index 5b86d03c394e8..042750dfe4372 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/RowConstructorCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowConstructorCodeGenerator.java @@ -27,6 +27,7 @@ import com.facebook.presto.sql.relational.RowExpression; import java.util.List; +import java.util.Optional; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantFalse; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt; @@ -62,7 +63,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon Variable field = scope.createTempVariable(javaType); block.comment("Clean wasNull and Generate + " + i + "-th field of row"); block.append(context.wasNull().set(constantFalse())); - block.append(context.generate(arguments.get(i))); + block.append(context.generate(arguments.get(i), context.getOutputBlockBuilder())); block.putVariable(field); block.append(new IfStatement() .condition(context.wasNull()) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java index 5fcc7165d5336..4bbbecc512ff1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java @@ -16,6 +16,7 @@ import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.BytecodeNode; import com.facebook.presto.bytecode.Scope; +import com.facebook.presto.bytecode.Variable; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.ConstantExpression; @@ -73,14 +74,14 @@ public class RowExpressionCompiler this.preGeneratedExpressions = preGeneratedExpressions; } - public BytecodeNode compile(RowExpression rowExpression, Scope scope) + public BytecodeNode compile(RowExpression rowExpression, Scope scope, Optional outputBlockBuilder) { - return compile(rowExpression, scope, Optional.empty()); + return compile(rowExpression, scope, outputBlockBuilder, Optional.empty()); } - public BytecodeNode compile(RowExpression rowExpression, Scope scope, Optional lambdaInterface) + public BytecodeNode compile(RowExpression rowExpression, Scope scope, Optional outputBlockBuilder, Optional lambdaInterface) { - return rowExpression.accept(new Visitor(), new Context(scope, lambdaInterface)); + return rowExpression.accept(new Visitor(), new Context(scope, outputBlockBuilder, lambdaInterface)); } private class Visitor @@ -148,6 +149,7 @@ public BytecodeNode visitCall(CallExpression call, Context context) callSiteBinder, cachedInstanceBinder, registry, + context.getOutputBlockBuilder(), preGeneratedExpressions); return generator.generateExpression(call.getSignature(), generatorContext, call.getType(), call.getArguments()); @@ -220,6 +222,7 @@ public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Context conte callSiteBinder, cachedInstanceBinder, registry, + context.getOutputBlockBuilder(), preGeneratedExpressions); return generateLambda( @@ -239,11 +242,13 @@ public BytecodeNode visitVariableReference(VariableReferenceExpression reference private static class Context { private final Scope scope; + private final Optional outputBlockBuilder; private final Optional lambdaInterface; - public Context(Scope scope, Optional lambdaInterface) + public Context(Scope scope, Optional outputBlockBuilder, Optional lambdaInterface) { this.scope = scope; + this.outputBlockBuilder = outputBlockBuilder; this.lambdaInterface = lambdaInterface; } @@ -252,6 +257,11 @@ public Scope getScope() return scope; } + public Optional getOutputBlockBuilder() + { + return outputBlockBuilder; + } + public Optional getLambdaInterface() { return lambdaInterface; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java index 98157b54a28ad..7e1f0a84d8eab 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java @@ -30,6 +30,7 @@ import com.google.common.collect.Lists; import java.util.List; +import java.util.Optional; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantFalse; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantTrue; @@ -74,7 +75,7 @@ else if ( == ) { // process value, else, and all when clauses RowExpression value = arguments.get(0); - BytecodeNode valueBytecode = generatorContext.generate(value); + BytecodeNode valueBytecode = generatorContext.generate(value, generatorContext.getOutputBlockBuilder()); BytecodeNode elseValue; List whenClauses; @@ -87,7 +88,7 @@ else if ( == ) { } else { whenClauses = arguments.subList(1, arguments.size() - 1); - elseValue = generatorContext.generate(last); + elseValue = generatorContext.generate(last, generatorContext.getOutputBlockBuilder()); } // determine the type of the value and result @@ -123,7 +124,7 @@ else if ( == ) { BytecodeNode equalsCall = generatorContext.generateCall( equalsFunction.getName(), generatorContext.getRegistry().getScalarFunctionImplementation(equalsFunction), - ImmutableList.of(generatorContext.generate(operand), getTempVariableNode)); + ImmutableList.of(generatorContext.generate(operand, generatorContext.getOutputBlockBuilder()), getTempVariableNode)); BytecodeBlock condition = new BytecodeBlock() .append(equalsCall) @@ -131,7 +132,7 @@ else if ( == ) { elseValue = new IfStatement("when") .condition(condition) - .ifTrue(generatorContext.generate(result)) + .ifTrue(generatorContext.generate(result, generatorContext.getOutputBlockBuilder())) .ifFalse(elseValue); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/TryCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/TryCodeGenerator.java index 68cdde929eb66..0c29a730d5a84 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/TryCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/TryCodeGenerator.java @@ -34,6 +34,7 @@ import java.lang.invoke.MethodType; import java.util.List; import java.util.Map; +import java.util.Optional; import static com.facebook.presto.bytecode.Access.PUBLIC; import static com.facebook.presto.bytecode.Access.a; @@ -102,7 +103,7 @@ public static MethodDefinition defineTryMethod( Scope calleeMethodScope = method.getScope(); Variable wasNull = calleeMethodScope.declareVariable(boolean.class, "wasNull"); - BytecodeNode innerExpression = innerExpressionCompiler.compile(innerRowExpression, calleeMethodScope); + BytecodeNode innerExpression = innerExpressionCompiler.compile(innerRowExpression, calleeMethodScope, Optional.empty()); MethodType exceptionHandlerType = methodType(returnType, PrestoException.class); MethodHandle exceptionHandler = EXCEPTION_HANDLER.asType(exceptionHandlerType); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpression.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpression.java index 7541d74d783c8..35dc62dc783b2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpression.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpression.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.relational; +import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.type.Type; public abstract class RowExpression @@ -29,4 +30,9 @@ public abstract class RowExpression public abstract String toString(); public abstract R accept(RowExpressionVisitor visitor, C context); + + public boolean directWriteToOutputBuilder(FunctionRegistry functionRegistry) + { + return false; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayIdentity.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayIdentity.java new file mode 100644 index 0000000000000..e1f9bb3cb5dd1 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayIdentity.java @@ -0,0 +1,176 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionKind; +import com.facebook.presto.metadata.FunctionListBuilder; +import com.facebook.presto.metadata.FunctionRegistry; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.SqlScalarFunction; +import com.facebook.presto.operator.project.PageProcessor; +import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.sql.gen.ExpressionCompiler; +import com.facebook.presto.sql.gen.PageFunctionCompiler; +import com.facebook.presto.sql.relational.CallExpression; +import com.facebook.presto.sql.relational.LambdaDefinitionExpression; +import com.facebook.presto.sql.relational.RowExpression; +import com.facebook.presto.sql.relational.VariableReferenceExpression; +import com.google.common.base.Throwables; +import com.google.common.base.Verify; +import com.google.common.collect.ImmutableList; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OperationsPerInvocation; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; + +import java.lang.invoke.MethodHandle; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +import static com.facebook.presto.metadata.Signature.typeVariable; +import static com.facebook.presto.operator.scalar.BenchmarkArrayFilter.ExactArrayFilterFunction.EXACT_ARRAY_FILTER_FUNCTION; +import static com.facebook.presto.spi.function.OperatorType.GREATER_THAN; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; +import static com.facebook.presto.sql.relational.Expressions.constant; +import static com.facebook.presto.sql.relational.Expressions.field; +import static com.facebook.presto.testing.TestingConnectorSession.SESSION; +import static com.facebook.presto.util.Reflection.methodHandle; +import static java.lang.Boolean.TRUE; + +@SuppressWarnings("MethodMayBeStatic") +@State(Scope.Thread) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@Fork(2) +@Warmup(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@BenchmarkMode(Mode.AverageTime) +public class BenchmarkArrayIdentity +{ + private static final int POSITIONS = 100_000; + private static final int ARRAY_SIZE = 4; + private static final int NUM_TYPES = 1; + private static final List TYPES = ImmutableList.of(BIGINT); + + static { + Verify.verify(NUM_TYPES == TYPES.size()); + } + + @Benchmark + @OperationsPerInvocation(POSITIONS * ARRAY_SIZE * NUM_TYPES) + public List benchmark(BenchmarkData data) + throws Throwable + { + return ImmutableList.copyOf(data.getPageProcessor().process(SESSION, data.getPage())); + } + + @SuppressWarnings("FieldMayBeFinal") + @State(Scope.Thread) + public static class BenchmarkData + { + @Param({"array_identity", "array_identity_direct"}) + private String name = "array_identity"; + + private Page page; + private PageProcessor pageProcessor; + + @Setup + public void setup() + { + MetadataManager metadata = MetadataManager.createTestMetadataManager(); + metadata.addFunctions(new FunctionListBuilder().function(EXACT_ARRAY_FILTER_FUNCTION).getFunctions()); + ExpressionCompiler compiler = new ExpressionCompiler(metadata, new PageFunctionCompiler(metadata, 0)); + ImmutableList.Builder projectionsBuilder = ImmutableList.builder(); + Block[] blocks = new Block[TYPES.size()]; + for (int i = 0; i < TYPES.size(); i++) { + Type elementType = TYPES.get(i); + ArrayType arrayType = new ArrayType(elementType); + Signature signature = new Signature(name, FunctionKind.SCALAR, arrayType.getTypeSignature(), arrayType.getTypeSignature()); + projectionsBuilder.add(new CallExpression(signature, arrayType, ImmutableList.of(field(0, arrayType)))); + blocks[i] = createChannel(POSITIONS, ARRAY_SIZE, arrayType); + } + + ImmutableList projections = projectionsBuilder.build(); + pageProcessor = compiler.compilePageProcessor(Optional.empty(), projections).get(); + page = new Page(blocks); + } + + private static Block createChannel(int positionCount, int arraySize, ArrayType arrayType) + { + BlockBuilder blockBuilder = arrayType.createBlockBuilder(new BlockBuilderStatus(), positionCount); + for (int position = 0; position < positionCount; position++) { + BlockBuilder entryBuilder = blockBuilder.beginBlockEntry(); + for (int i = 0; i < arraySize; i++) { + if (arrayType.getElementType().getJavaType() == long.class) { + arrayType.getElementType().writeLong(entryBuilder, ThreadLocalRandom.current().nextLong()); + } + else { + throw new UnsupportedOperationException(); + } + } + blockBuilder.closeEntry(); + } + return blockBuilder.build(); + } + + public PageProcessor getPageProcessor() + { + return pageProcessor; + } + + public Page getPage() + { + return page; + } + } + + public static void main(String[] args) + throws Throwable + { + // assure the benchmarks are valid before running + BenchmarkData data = new BenchmarkData(); + data.setup(); + new BenchmarkArrayIdentity().benchmark(data); + + Options options = new OptionsBuilder() + .verbosity(VerboseMode.NORMAL) + .include(".*" + BenchmarkArrayIdentity.class.getSimpleName() + ".*") + .build(); + new Runner(options).run(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestVarArgsToArrayAdapterGenerator.java b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestVarArgsToArrayAdapterGenerator.java index 9714a705e6d6e..92321e0cf6cc1 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestVarArgsToArrayAdapterGenerator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestVarArgsToArrayAdapterGenerator.java @@ -119,6 +119,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in nCopies(arity, Optional.empty()), methodHandleAndConstructor.getMethodHandle(), Optional.of(methodHandleAndConstructor.getConstructor()), + false, isDeterministic()); } diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java b/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java index 42d364bad440f..863328745430a 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java @@ -1165,6 +1165,19 @@ public void assertInvalidFunction(String projection, SemanticErrorCode errorCode } } + @Test + public void testIdentity() + { + assertFunction("identity_array(ARRAY[1, 2, 3])", new ArrayType(INTEGER), ImmutableList.of(1, 2, 3)); + } + + @Test + public void testIdentityDirect() + { + assertFunction("identity_array_direct(ARRAY[1, 2, 3])", new ArrayType(INTEGER), ImmutableList.of(1, 2, 3)); + } + + @Test public void testFlatten() {