Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.metadata.SqlScalarFunction;
import com.facebook.presto.operator.scalar.BuiltInScalarFunctionImplementation.ArgumentProperty;
import com.facebook.presto.operator.scalar.BuiltInScalarFunctionImplementation.ReturnPlaceConvention;
import com.facebook.presto.operator.scalar.BuiltInScalarFunctionImplementation.ScalarImplementationChoice;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.PrestoException;
Expand Down Expand Up @@ -46,6 +48,7 @@
import static com.facebook.presto.operator.scalar.BuiltInScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty;
import static com.facebook.presto.operator.scalar.BuiltInScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL;
import static com.facebook.presto.operator.scalar.BuiltInScalarFunctionImplementation.NullConvention.USE_BOXED_TYPE;
import static com.facebook.presto.operator.scalar.BuiltInScalarFunctionImplementation.ReturnPlaceConvention.PROVIDED_BLOCKBUILDER;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.function.Signature.typeVariable;
Expand All @@ -64,15 +67,25 @@ public final class ArrayJoin
private static final TypeSignature VARCHAR_TYPE_SIGNATURE = VARCHAR.getTypeSignature();
private static final String FUNCTION_NAME = "array_join";
private static final String DESCRIPTION = "Concatenates the elements of the given array using a delimiter and an optional string to replace nulls";
private static final MethodHandle METHOD_HANDLE = methodHandle(

private static final MethodHandle METHOD_HANDLE_STACK = methodHandle(
ArrayJoin.class,
"arrayJoin",
"arrayJoinStack",
MethodHandle.class,
Object.class,
ConnectorSession.class,
Block.class,
Slice.class);

private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK = methodHandle(
ArrayJoin.class,
"arrayJoinProvidedBlock",
MethodHandle.class,
ConnectorSession.class,
BlockBuilder.class,
Block.class,
Slice.class);

private static final MethodHandle GET_BOOLEAN = methodHandle(Type.class, "getBoolean", Block.class, int.class);
private static final MethodHandle GET_DOUBLE = methodHandle(Type.class, "getDouble", Block.class, int.class);
private static final MethodHandle GET_LONG = methodHandle(Type.class, "getLong", Block.class, int.class);
Expand All @@ -83,16 +96,26 @@ public final class ArrayJoin
public static class ArrayJoinWithNullReplacement
extends SqlScalarFunction
{
private static final MethodHandle METHOD_HANDLE = methodHandle(
private static final MethodHandle METHOD_HANDLE_STACK = methodHandle(
ArrayJoin.class,
"arrayJoin",
"arrayJoinStack",
MethodHandle.class,
Object.class,
ConnectorSession.class,
Block.class,
Slice.class,
Slice.class);

private static final MethodHandle METHOD_HANDLE_PROVIDED_BLOCK = methodHandle(
ArrayJoin.class,
"arrayJoinProvidedBlock",
MethodHandle.class,
ConnectorSession.class,
BlockBuilder.class,
Block.class,
Slice.class,
Slice.class);

public ArrayJoinWithNullReplacement()
{
super(new Signature(
Expand Down Expand Up @@ -126,7 +149,12 @@ public String getDescription()
@Override
public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionManager functionManager)
{
return specializeArrayJoin(boundVariables.getTypeVariables(), functionManager, ImmutableList.of(false, false, false), METHOD_HANDLE);
return specializeArrayJoin(
boundVariables.getTypeVariables(),
functionManager,
ImmutableList.of(false, false, false),
METHOD_HANDLE_STACK,
METHOD_HANDLE_PROVIDED_BLOCK);
}
}

Expand Down Expand Up @@ -169,10 +197,20 @@ public String getDescription()
@Override
public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionManager functionManager)
{
return specializeArrayJoin(boundVariables.getTypeVariables(), functionManager, ImmutableList.of(false, false), METHOD_HANDLE);
return specializeArrayJoin(
boundVariables.getTypeVariables(),
functionManager,
ImmutableList.of(false, false),
METHOD_HANDLE_STACK,
METHOD_HANDLE_PROVIDED_BLOCK);
}

private static BuiltInScalarFunctionImplementation specializeArrayJoin(Map<String, Type> types, FunctionManager functionManager, List<Boolean> nullableArguments, MethodHandle methodHandle)
private static BuiltInScalarFunctionImplementation specializeArrayJoin(
Map<String, Type> types,
FunctionManager functionManager,
List<Boolean> nullableArguments,
MethodHandle methodHandleStack,
MethodHandle methodHandleProvidedBlock)
{
Type type = types.get("T");
List<ArgumentProperty> argumentProperties = nullableArguments.stream()
Expand All @@ -185,7 +223,7 @@ private static BuiltInScalarFunctionImplementation specializeArrayJoin(Map<Strin
return new BuiltInScalarFunctionImplementation(
false,
argumentProperties,
methodHandle.bindTo(null),
methodHandleStack.bindTo(null),
Optional.of(STATE_FACTORY));
}
else {
Expand Down Expand Up @@ -225,12 +263,22 @@ else if (elementType == Slice.class) {
cast = MethodHandles.dropArguments(cast, 1, Block.class);
cast = MethodHandles.foldArguments(cast, getter.bindTo(type));

MethodHandle target = MethodHandles.insertArguments(methodHandle, 0, cast);
MethodHandle targetStack = MethodHandles.insertArguments(methodHandleStack, 0, cast);
MethodHandle targetProvidedBlock = MethodHandles.insertArguments(methodHandleProvidedBlock, 0, cast);
return new BuiltInScalarFunctionImplementation(
false,
argumentProperties,
target,
Optional.of(STATE_FACTORY));
ImmutableList.of(
new ScalarImplementationChoice(
false,
argumentProperties,
ReturnPlaceConvention.STACK,
targetStack,
Optional.of(STATE_FACTORY)),
new ScalarImplementationChoice(
false,
argumentProperties,
PROVIDED_BLOCKBUILDER,
targetProvidedBlock,
Optional.empty())));
}
catch (PrestoException e) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Input type %s not supported", type), e);
Expand All @@ -239,18 +287,29 @@ else if (elementType == Slice.class) {
}

@UsedByGeneratedCode
public static Slice arrayJoin(
public static Slice arrayJoinStack(
MethodHandle castFunction,
Object state,
ConnectorSession session,
Block arrayBlock,
Slice delimiter)
{
return arrayJoin(castFunction, state, session, arrayBlock, delimiter, null);
return arrayJoinStack(castFunction, state, session, arrayBlock, delimiter, null);
}

@UsedByGeneratedCode
public static void arrayJoinProvidedBlock(
MethodHandle castFunction,
ConnectorSession session,
BlockBuilder blockBuilder,
Block arrayBlock,
Slice delimiter)
{
arrayJoinProvidedBlock(castFunction, session, blockBuilder, arrayBlock, delimiter, null);
}

@UsedByGeneratedCode
public static Slice arrayJoin(
public static Slice arrayJoinStack(
MethodHandle castFunction,
Object state,
ConnectorSession session,
Expand All @@ -262,9 +321,32 @@ public static Slice arrayJoin(
if (pageBuilder.isFull()) {
pageBuilder.reset();
}
int numElements = arrayBlock.getPositionCount();

BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0);

try {
arrayJoinProvidedBlock(castFunction, session, blockBuilder, arrayBlock, delimiter, nullReplacement);
}
catch (PrestoException e) {
// Restore pageBuilder into a consistent state
pageBuilder.declarePosition();
}

pageBuilder.declarePosition();
return VARCHAR.getSlice(blockBuilder, blockBuilder.getPositionCount() - 1);
}

@UsedByGeneratedCode
public static void arrayJoinProvidedBlock(
MethodHandle castFunction,
ConnectorSession session,
BlockBuilder blockBuilder,
Block arrayBlock,
Slice delimiter,
Slice nullReplacement)
{
int numElements = arrayBlock.getPositionCount();

for (int i = 0; i < numElements; i++) {
if (arrayBlock.isNull(i)) {
if (nullReplacement != null) {
Expand All @@ -280,9 +362,8 @@ public static Slice arrayJoin(
blockBuilder.writeBytes(slice, 0, slice.length());
}
catch (Throwable throwable) {
// Restore pageBuilder into a consistent state
// Restore blockBuilder into a consistent state
blockBuilder.closeEntry();
pageBuilder.declarePosition();
throw new PrestoException(GENERIC_INTERNAL_ERROR, "Error casting array element to VARCHAR", throwable);
}
}
Expand All @@ -293,7 +374,5 @@ public static Slice arrayJoin(
}

blockBuilder.closeEntry();
pageBuilder.declarePosition();
return VARCHAR.getSlice(blockBuilder, blockBuilder.getPositionCount() - 1);
}
}