Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package io.trino.operator.scalar;

import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.Description;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.SqlType;
Expand All @@ -39,19 +38,22 @@ public static Block filterLong(
@SqlType("function(T, boolean)") LongToBooleanFunction function)
{
int positionCount = arrayBlock.getPositionCount();
BlockBuilder resultBuilder = elementType.createBlockBuilder(null, positionCount);
int[] positions = new int[positionCount];
int length = 0;
for (int position = 0; position < positionCount; position++) {
Long input = null;
if (!arrayBlock.isNull(position)) {
input = elementType.getLong(arrayBlock, position);
}

Boolean keep = function.apply(input);
if (TRUE.equals(keep)) {
elementType.appendTo(arrayBlock, position, resultBuilder);
}
positions[length] = position;
length += TRUE.equals(keep) ? 1 : 0;
}
if (positions.length == length) {
return arrayBlock;
}
return resultBuilder.build();
return arrayBlock.copyPositions(positions, 0, length);
}

@TypeParameter("T")
Expand All @@ -63,19 +65,22 @@ public static Block filterDouble(
@SqlType("function(T, boolean)") DoubleToBooleanFunction function)
{
int positionCount = arrayBlock.getPositionCount();
BlockBuilder resultBuilder = elementType.createBlockBuilder(null, positionCount);
int[] positions = new int[positionCount];
int length = 0;
for (int position = 0; position < positionCount; position++) {
Double input = null;
if (!arrayBlock.isNull(position)) {
input = elementType.getDouble(arrayBlock, position);
}

Boolean keep = function.apply(input);
if (TRUE.equals(keep)) {
elementType.appendTo(arrayBlock, position, resultBuilder);
}
positions[length] = position;
length += TRUE.equals(keep) ? 1 : 0;
}
if (positions.length == length) {
return arrayBlock;
}
return resultBuilder.build();
return arrayBlock.copyPositions(positions, 0, length);
}

@TypeParameter("T")
Expand All @@ -87,19 +92,22 @@ public static Block filterBoolean(
@SqlType("function(T, boolean)") BooleanToBooleanFunction function)
{
int positionCount = arrayBlock.getPositionCount();
BlockBuilder resultBuilder = elementType.createBlockBuilder(null, positionCount);
int[] positions = new int[positionCount];
int length = 0;
for (int position = 0; position < positionCount; position++) {
Boolean input = null;
if (!arrayBlock.isNull(position)) {
input = elementType.getBoolean(arrayBlock, position);
}

Boolean keep = function.apply(input);
if (TRUE.equals(keep)) {
elementType.appendTo(arrayBlock, position, resultBuilder);
}
positions[length] = position;
length += TRUE.equals(keep) ? 1 : 0;
}
if (positions.length == length) {
return arrayBlock;
}
return resultBuilder.build();
return arrayBlock.copyPositions(positions, 0, length);
}

@TypeParameter("T")
Expand All @@ -111,18 +119,21 @@ public static Block filterObject(
@SqlType("function(T, boolean)") ObjectToBooleanFunction function)
{
int positionCount = arrayBlock.getPositionCount();
Comment thread
raunaqmorarka marked this conversation as resolved.
Outdated
BlockBuilder resultBuilder = elementType.createBlockBuilder(null, positionCount);
int[] positions = new int[positionCount];
int length = 0;
for (int position = 0; position < positionCount; position++) {
Object input = null;
if (!arrayBlock.isNull(position)) {
input = elementType.getObject(arrayBlock, position);
}

Boolean keep = function.apply(input);
if (TRUE.equals(keep)) {
elementType.appendTo(arrayBlock, position, resultBuilder);
}
positions[length] = position;
length += TRUE.equals(keep) ? 1 : 0;
}
if (positions.length == length) {
return arrayBlock;
}
return resultBuilder.build();
return arrayBlock.copyPositions(positions, 0, length);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.Signature;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.sql.gen.ExpressionCompiler;
import io.trino.sql.relational.CallExpression;
import io.trino.sql.relational.LambdaDefinitionExpression;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.SpecialForm;
import io.trino.sql.relational.VariableReferenceExpression;
import io.trino.sql.tree.QualifiedName;
import io.trino.type.FunctionType;
Expand All @@ -58,19 +60,24 @@

import static com.google.common.base.Throwables.throwIfUnchecked;
import static com.google.common.base.Verify.verify;
import static io.trino.block.BlockAssertions.createRandomBlockForType;
import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static io.trino.operator.scalar.BenchmarkArrayFilter.ExactArrayFilterFunction.EXACT_ARRAY_FILTER_FUNCTION;
import static io.trino.operator.scalar.BenchmarkArrayFilter.ExactArrayFilterObjectFunction.EXACT_ARRAY_FILTER_OBJECT_FUNCTION;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.OperatorType.LESS_THAN;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.TypeSignature.arrayType;
import static io.trino.spi.type.TypeSignature.functionType;
import static io.trino.spi.type.TypeUtils.readNativeValue;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.sql.relational.Expressions.constant;
import static io.trino.sql.relational.Expressions.field;
import static io.trino.sql.relational.SpecialForm.Form.DEREFERENCE;
import static io.trino.testing.TestingConnectorSession.SESSION;
import static io.trino.util.Reflection.methodHandle;
import static java.lang.Boolean.TRUE;
Expand All @@ -88,9 +95,11 @@ public class BenchmarkArrayFilter
private static final int ARRAY_SIZE = 4;
private static final int NUM_TYPES = 1;
private static final List<Type> TYPES = ImmutableList.of(BIGINT);
private static final List<Type> ROW_TYPES = ImmutableList.of(RowType.anonymous(ImmutableList.of(BIGINT, DOUBLE)));

static {
verify(NUM_TYPES == TYPES.size());
verify(NUM_TYPES == ROW_TYPES.size());
}

@Benchmark
Expand All @@ -105,6 +114,18 @@ public List<Optional<Page>> benchmark(BenchmarkData data)
data.getPage()));
}

@Benchmark
@OperationsPerInvocation(POSITIONS * ARRAY_SIZE * NUM_TYPES)
public List<Optional<Page>> benchmarkObject(RowBenchmarkData data)
{
return ImmutableList.copyOf(
data.getPageProcessor().process(
SESSION,
new DriverYieldSignal(),
newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()),
data.getPage()));
}

@SuppressWarnings("FieldMayBeFinal")
@State(Scope.Thread)
public static class BenchmarkData
Expand Down Expand Up @@ -172,15 +193,85 @@ public Page getPage()
}
}

@SuppressWarnings("FieldMayBeFinal")
@State(Scope.Thread)
public static class RowBenchmarkData
{
@Param({"filter", "exact_filter"})
private String name = "filter";

private Page page;
private PageProcessor pageProcessor;

@Setup
public void setup()
{
TestingFunctionResolution functionResolution = new TestingFunctionResolution(InternalFunctionBundle.builder().function(EXACT_ARRAY_FILTER_OBJECT_FUNCTION).build());
ExpressionCompiler compiler = functionResolution.getExpressionCompiler();
ImmutableList.Builder<RowExpression> projectionsBuilder = ImmutableList.builder();
Block[] blocks = new Block[ROW_TYPES.size()];
for (int i = 0; i < ROW_TYPES.size(); i++) {
Type elementType = ROW_TYPES.get(i);
ArrayType arrayType = new ArrayType(elementType);
ResolvedFunction resolvedFunction = functionResolution.resolveFunction(
QualifiedName.of(name),
fromTypes(arrayType, new FunctionType(ROW_TYPES, BOOLEAN)));
ResolvedFunction lessThan = functionResolution.resolveOperator(LESS_THAN, ImmutableList.of(BIGINT, BIGINT));

projectionsBuilder.add(new CallExpression(resolvedFunction, ImmutableList.of(
field(0, arrayType),
new LambdaDefinitionExpression(
ImmutableList.of(elementType),
ImmutableList.of("x"),
new CallExpression(
lessThan,
ImmutableList.of(
constant(0L, BIGINT),
new SpecialForm(
DEREFERENCE,
BIGINT,
new VariableReferenceExpression("x", elementType),
constant(0, INTEGER))))))));
blocks[i] = createChannel(POSITIONS, arrayType);
}

ImmutableList<RowExpression> projections = projectionsBuilder.build();
pageProcessor = compiler.compilePageProcessor(Optional.empty(), projections).get();
page = new Page(blocks);
}

private static Block createChannel(int positionCount, ArrayType arrayType)
{
return createRandomBlockForType(arrayType, positionCount, 0.2F);
}

public PageProcessor getPageProcessor()
{
return pageProcessor;
}

public Page getPage()
{
return page;
}
}

public static void main(String[] args)
throws Exception
{
// assure the benchmarks are valid before running
BenchmarkData data = new BenchmarkData();
data.setup();
new BenchmarkArrayFilter().benchmark(data);
BenchmarkArrayFilter benchmarkArrayFilter = new BenchmarkArrayFilter();
benchmarkArrayFilter.benchmark(data);

Benchmarks.benchmark(BenchmarkArrayFilter.class).run();
RowBenchmarkData rowData = new RowBenchmarkData();
rowData.setup();
benchmarkArrayFilter.benchmarkObject(rowData);

Benchmarks.benchmark(BenchmarkArrayFilter.class)
.withOptions(optionsBuilder -> optionsBuilder.jvmArgs("-Xmx4g"))
.run();
}

public static final class ExactArrayFilterFunction
Expand Down Expand Up @@ -237,4 +328,59 @@ public static Block filter(Type type, Block block, MethodHandle function)
return resultBuilder.build();
}
}

public static final class ExactArrayFilterObjectFunction
extends SqlScalarFunction
{
public static final ExactArrayFilterObjectFunction EXACT_ARRAY_FILTER_OBJECT_FUNCTION = new ExactArrayFilterObjectFunction();

private static final MethodHandle METHOD_HANDLE = methodHandle(ExactArrayFilterObjectFunction.class, "filterObject", Type.class, Block.class, MethodHandle.class);

private ExactArrayFilterObjectFunction()
{
super(FunctionMetadata.scalarBuilder()
.signature(Signature.builder()
.name("exact_filter")
.typeVariable("T")
.returnType(arrayType(new TypeSignature("T")))
.argumentType(arrayType(new TypeSignature("T")))
.argumentType(functionType(new TypeSignature("T"), BOOLEAN.getTypeSignature()))
.build())
.nondeterministic()
.description("return array containing elements that match the given predicate")
.build());
}

@Override
protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature)
{
Type type = ((ArrayType) boundSignature.getReturnType()).getElementType();
return new ChoicesSpecializedSqlScalarFunction(
boundSignature,
FAIL_ON_NULL,
ImmutableList.of(NEVER_NULL, NEVER_NULL),
METHOD_HANDLE.bindTo(type));
}

public static Block filterObject(Type type, Block block, MethodHandle function)
{
int positionCount = block.getPositionCount();
BlockBuilder resultBuilder = type.createBlockBuilder(null, positionCount);
for (int position = 0; position < positionCount; position++) {
Object input = type.getObject(block, position);
Boolean keep;
try {
keep = (Boolean) function.invokeExact(input);
}
catch (Throwable t) {
throwIfUnchecked(t);
throw new RuntimeException(t);
}
if (TRUE.equals(keep)) {
type.appendTo(block, position, resultBuilder);
}
}
return resultBuilder.build();
}
}
}