Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.prestosql.operator.scalar.ScalarFunctionImplementation;
import io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentProperty;
import io.prestosql.operator.scalar.ScalarFunctionImplementation.NullConvention;
import io.prestosql.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention;
import io.prestosql.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeManager;
Expand Down Expand Up @@ -124,7 +125,7 @@ private ScalarImplementationChoice getScalarFunctionImplementationChoice(

List<Object> extraParameters = computeExtraParameters(matchingMethodsGroup.get(), context);
MethodHandle methodHandle = applyExtraParameters(matchingMethod.get().getMethod(), extraParameters, choice.getArgumentProperties());
return new ScalarImplementationChoice(choice.isNullableResult(), choice.getArgumentProperties(), methodHandle, Optional.empty());
return new ScalarImplementationChoice(choice.isNullableResult(), choice.getArgumentProperties(), choice.getReturnPlaceConvention(), methodHandle, Optional.empty());
}

private static boolean matchesParameterAndReturnTypes(
Expand Down Expand Up @@ -218,15 +219,18 @@ static final class PolymorphicScalarFunctionChoice
{
private final boolean nullableResult;
private final List<ArgumentProperty> argumentProperties;
private final ReturnPlaceConvention returnPlaceConvention;
private final List<MethodsGroup> methodsGroups;

PolymorphicScalarFunctionChoice(
boolean nullableResult,
List<ArgumentProperty> argumentProperties,
ReturnPlaceConvention returnPlaceConvention,
List<MethodsGroup> methodsGroups)
{
this.nullableResult = nullableResult;
this.argumentProperties = ImmutableList.copyOf(requireNonNull(argumentProperties, "argumentProperties is null"));
this.returnPlaceConvention = requireNonNull(returnPlaceConvention, "returnPlaceConvention is null");
this.methodsGroups = ImmutableList.copyOf(requireNonNull(methodsGroups, "methodsWithExtraParametersFunctions is null"));
}

Expand All @@ -244,5 +248,10 @@ List<ArgumentProperty> getArgumentProperties()
{
return argumentProperties;
}

ReturnPlaceConvention getReturnPlaceConvention()
{
return returnPlaceConvention;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.collect.ImmutableList;
import io.prestosql.metadata.PolymorphicScalarFunction.PolymorphicScalarFunctionChoice;
import io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentProperty;
import io.prestosql.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention;
import io.prestosql.spi.function.OperatorType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeManager;
Expand Down Expand Up @@ -250,6 +251,7 @@ public static class ChoiceBuilder
private final Signature signature;
private boolean nullableResult;
private List<ArgumentProperty> argumentProperties;
private ReturnPlaceConvention returnPlaceConvention;
private final ImmutableList.Builder<MethodsGroup> methodsGroups = ImmutableList.builder();

private ChoiceBuilder(Class<?> clazz, Signature signature)
Expand All @@ -264,6 +266,10 @@ public ChoiceBuilder implementation(Function<MethodsGroupBuilder, MethodsGroupBu
if (argumentProperties == null) {
argumentProperties = nCopies(signature.getArgumentTypes().size(), valueTypeArgumentProperty(RETURN_NULL_ON_NULL));
}
// if the returnPlaceConvention is not set yet. We assume it is set to the default value.
if (returnPlaceConvention == null) {
returnPlaceConvention = ReturnPlaceConvention.STACK;
}
MethodsGroupBuilder methodsGroupBuilder = new MethodsGroupBuilder(clazz, signature, argumentProperties);
methodsGroupSpecification.apply(methodsGroupBuilder);
methodsGroups.add(methodsGroupBuilder.build());
Expand All @@ -285,9 +291,18 @@ public ChoiceBuilder argumentProperties(ArgumentProperty... argumentProperties)
return this;
}

public ChoiceBuilder returnPlaceConvention(ReturnPlaceConvention returnPlaceConvention)
{
requireNonNull(returnPlaceConvention, "returnPlaceConvention is null");
checkState(this.returnPlaceConvention == null,
"The `returnPlaceConvention` method must be invoked only once, and must be invoked before the `implementation` method");
this.returnPlaceConvention = returnPlaceConvention;
return this;
}

public PolymorphicScalarFunctionChoice build()
{
return new PolymorphicScalarFunctionChoice(nullableResult, argumentProperties, methodsGroups.build());
return new PolymorphicScalarFunctionChoice(nullableResult, argumentProperties, returnPlaceConvention, methodsGroups.build());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.prestosql.metadata.FunctionRegistry;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.SqlOperator;
import io.prestosql.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention;
import io.prestosql.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.function.InvocationConvention;
Expand Down Expand Up @@ -78,11 +79,13 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in
new ScalarImplementationChoice(
false,
ImmutableList.of(valueTypeArgumentProperty(USE_NULL_FLAG), valueTypeArgumentProperty(USE_NULL_FLAG)),
ReturnPlaceConvention.STACK,
METHOD_HANDLE_NULL_FLAG.bindTo(type).bindTo(argumentMethods.build()),
Optional.empty()),
new ScalarImplementationChoice(
false,
ImmutableList.of(valueTypeArgumentProperty(BLOCK_AND_POSITION), valueTypeArgumentProperty(BLOCK_AND_POSITION)),
ReturnPlaceConvention.STACK,
METHOD_HANDLE_BLOCK_POSITION.bindTo(type).bindTo(argumentMethods.build()),
Optional.empty())),
isDeterministic());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,14 @@ public ScalarFunctionImplementation(
Optional<MethodHandle> instanceFactory,
boolean deterministic)
{
this(ImmutableList.of(new ScalarImplementationChoice(nullable, argumentProperties, methodHandle, instanceFactory)), deterministic);
this(
ImmutableList.of(new ScalarImplementationChoice(
nullable,
argumentProperties,
ReturnPlaceConvention.STACK,
methodHandle,
instanceFactory)),
deterministic);
}

/**
Expand Down Expand Up @@ -107,18 +114,21 @@ public static class ScalarImplementationChoice
{
private final boolean nullable;
private final List<ArgumentProperty> argumentProperties;
private final ReturnPlaceConvention returnPlaceConvention;
private final MethodHandle methodHandle;
private final Optional<MethodHandle> instanceFactory;
private final boolean hasSession;

public ScalarImplementationChoice(
boolean nullable,
List<ArgumentProperty> argumentProperties,
ReturnPlaceConvention returnPlaceConvention,
MethodHandle methodHandle,
Optional<MethodHandle> instanceFactory)
{
this.nullable = nullable;
this.argumentProperties = ImmutableList.copyOf(requireNonNull(argumentProperties, "argumentProperties is null"));
this.returnPlaceConvention = requireNonNull(returnPlaceConvention, "returnPlaceConvention is null");
this.methodHandle = requireNonNull(methodHandle, "methodHandle is null");
this.instanceFactory = requireNonNull(instanceFactory, "instanceFactory is null");

Expand Down Expand Up @@ -158,6 +168,11 @@ public ArgumentProperty getArgumentProperty(int argumentIndex)
return argumentProperties.get(argumentIndex);
}

public ReturnPlaceConvention getReturnPlaceConvention()
{
return returnPlaceConvention;
}

public MethodHandle getMethodHandle()
{
return methodHandle;
Expand Down Expand Up @@ -279,4 +294,10 @@ public enum ArgumentType
VALUE_TYPE,
FUNCTION_TYPE
}

public enum ReturnPlaceConvention
{
STACK,
PROVIDED_BLOCKBUILDER
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.prestosql.operator.scalar.ScalarFunctionImplementation;
import io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentProperty;
import io.prestosql.operator.scalar.ScalarFunctionImplementation.NullConvention;
import io.prestosql.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention;
import io.prestosql.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.connector.ConnectorSession;
Expand Down Expand Up @@ -164,6 +165,7 @@ public Optional<ScalarFunctionImplementation> specialize(Signature boundSignatur
implementationChoices.add(new ScalarImplementationChoice(
choice.nullable,
choice.argumentProperties,
choice.returnPlaceConvention,
boundMethodHandle.asType(javaMethodType(choice, boundSignature, typeManager)),
boundConstructor));
}
Expand Down Expand Up @@ -305,6 +307,7 @@ public static final class ParametricScalarImplementationChoice
{
private final boolean nullable;
private final List<ArgumentProperty> argumentProperties;
private final ReturnPlaceConvention returnPlaceConvention;
private final MethodHandle methodHandle;
private final Optional<MethodHandle> constructor;
private final List<ImplementationDependency> dependencies;
Expand All @@ -316,6 +319,7 @@ private ParametricScalarImplementationChoice(
boolean nullable,
boolean hasConnectorSession,
List<ArgumentProperty> argumentProperties,
ReturnPlaceConvention returnPlaceConvention,
MethodHandle methodHandle,
Optional<MethodHandle> constructor,
List<ImplementationDependency> dependencies,
Expand All @@ -324,6 +328,7 @@ private ParametricScalarImplementationChoice(
this.nullable = nullable;
this.hasConnectorSession = hasConnectorSession;
this.argumentProperties = ImmutableList.copyOf(requireNonNull(argumentProperties, "argumentProperties is null"));
this.returnPlaceConvention = requireNonNull(returnPlaceConvention, "returnPlaceConvention is null");
this.methodHandle = requireNonNull(methodHandle, "methodHandle is null");
this.constructor = requireNonNull(constructor, "constructor is null");
this.dependencies = ImmutableList.copyOf(requireNonNull(dependencies, "dependencies is null"));
Expand Down Expand Up @@ -364,6 +369,11 @@ public List<ArgumentProperty> getArgumentProperties()
return argumentProperties;
}

public ReturnPlaceConvention getReturnPlaceConvention()
{
return returnPlaceConvention;
}

public boolean checkDependencies()
{
for (int i = 1; i < getDependencies().size(); i++) {
Expand Down Expand Up @@ -503,7 +513,15 @@ else if (actualReturnType.isPrimitive()) {

this.methodHandle = getMethodHandle(method);

ParametricScalarImplementationChoice choice = new ParametricScalarImplementationChoice(nullable, hasConnectorSession, argumentProperties, methodHandle, constructorMethodHandle, dependencies, constructorDependencies);
ParametricScalarImplementationChoice choice = new ParametricScalarImplementationChoice(
nullable,
hasConnectorSession,
argumentProperties,
ReturnPlaceConvention.STACK, // TODO: support other return place convention
methodHandle,
constructorMethodHandle,
dependencies,
constructorDependencies);
choices.add(choice);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,16 @@
import io.prestosql.sql.relational.RowExpression;

import java.util.List;
import java.util.Optional;

import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse;
import static io.prestosql.sql.gen.BytecodeGenerator.generateWrite;

public class AndCodeGenerator
implements BytecodeGenerator
{
@Override
public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generator, Type returnType, List<RowExpression> arguments)
public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generator, Type returnType, List<RowExpression> arguments, Optional<Variable> outputBlockVariable)
{
Preconditions.checkArgument(arguments.size() == 2);

Expand All @@ -40,8 +42,8 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon
.comment("AND")
.setDescription("AND");

BytecodeNode left = generator.generate(arguments.get(0));
BytecodeNode right = generator.generate(arguments.get(1));
BytecodeNode left = generator.generate(arguments.get(0), Optional.empty());
BytecodeNode right = generator.generate(arguments.get(1), Optional.empty());

block.append(left);

Expand Down Expand Up @@ -97,6 +99,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon
block.append(ifRightIsNull)
.visitLabel(end);

outputBlockVariable.ifPresent(output -> block.append(generateWrite(generator, returnType, output)));
return block;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package io.prestosql.sql.gen;

import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.Variable;
import io.prestosql.metadata.Signature;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.gen.LambdaBytecodeGenerator.CompiledLambda;
Expand All @@ -23,7 +24,9 @@

import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;

public class BindCodeGenerator
Expand All @@ -39,12 +42,18 @@ public BindCodeGenerator(Map<LambdaDefinitionExpression, CompiledLambda> compile
}

@Override
public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List<RowExpression> arguments)
public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List<RowExpression> arguments, Optional<Variable> outputBlockVariable)
{
// Bind expression is used to generate captured lambda.
// It takes the captured values and the uncaptured lambda, and produces captured lambda as the output.
// The uncaptured lambda is just a method, and does not have a stack representation during execution.
// As a result, the bind expression generates the captured lambda in one step.

// outputBlockVariable cannot present because
// 1. bind cannot be in the top level of an expression
// 2. lambda cannot be put into blocks.
checkArgument(!outputBlockVariable.isPresent());

int numCaptures = arguments.size() - 1;
LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) arguments.get(numCaptures);
checkState(compiledLambdaMap.containsKey(lambda), "lambda expressions map does not contain this lambda definition");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,25 @@
package io.prestosql.sql.gen;

import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.Variable;
import io.prestosql.metadata.Signature;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.relational.RowExpression;

import java.util.List;
import java.util.Optional;

public interface BytecodeGenerator
{
BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List<RowExpression> arguments);
BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List<RowExpression> arguments, Optional<Variable> outputBlockVariable);

static BytecodeNode generateWrite(BytecodeGeneratorContext context, Type returnType, Variable outputBlock)
{
return BytecodeUtils.generateWrite(
context.getCallSiteBinder(),
context.getScope(),
context.getScope().getVariable("wasNull"),
returnType,
outputBlock);
}
}
Loading