diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayFilterFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayFilterFunction.java index 0fc48db19d6d..192ec8394285 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayFilterFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayFilterFunction.java @@ -14,7 +14,6 @@ package io.trino.operator.scalar; import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.Description; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlType; @@ -39,7 +38,8 @@ public static Block filterLong( @SqlType("function(T, boolean)") LongToBooleanFunction function) { int positionCount = arrayBlock.getPositionCount(); - BlockBuilder resultBuilder = elementType.createBlockBuilder(null, positionCount); + int[] positions = new int[positionCount]; + int length = 0; for (int position = 0; position < positionCount; position++) { Long input = null; if (!arrayBlock.isNull(position)) { @@ -47,11 +47,13 @@ public static Block filterLong( } Boolean keep = function.apply(input); - if (TRUE.equals(keep)) { - elementType.appendTo(arrayBlock, position, resultBuilder); - } + positions[length] = position; + length += TRUE.equals(keep) ? 1 : 0; + } + if (positions.length == length) { + return arrayBlock; } - return resultBuilder.build(); + return arrayBlock.copyPositions(positions, 0, length); } @TypeParameter("T") @@ -63,7 +65,8 @@ public static Block filterDouble( @SqlType("function(T, boolean)") DoubleToBooleanFunction function) { int positionCount = arrayBlock.getPositionCount(); - BlockBuilder resultBuilder = elementType.createBlockBuilder(null, positionCount); + int[] positions = new int[positionCount]; + int length = 0; for (int position = 0; position < positionCount; position++) { Double input = null; if (!arrayBlock.isNull(position)) { @@ -71,11 +74,13 @@ public static Block filterDouble( } Boolean keep = function.apply(input); - if (TRUE.equals(keep)) { - elementType.appendTo(arrayBlock, position, resultBuilder); - } + positions[length] = position; + length += TRUE.equals(keep) ? 1 : 0; + } + if (positions.length == length) { + return arrayBlock; } - return resultBuilder.build(); + return arrayBlock.copyPositions(positions, 0, length); } @TypeParameter("T") @@ -87,7 +92,8 @@ public static Block filterBoolean( @SqlType("function(T, boolean)") BooleanToBooleanFunction function) { int positionCount = arrayBlock.getPositionCount(); - BlockBuilder resultBuilder = elementType.createBlockBuilder(null, positionCount); + int[] positions = new int[positionCount]; + int length = 0; for (int position = 0; position < positionCount; position++) { Boolean input = null; if (!arrayBlock.isNull(position)) { @@ -95,11 +101,13 @@ public static Block filterBoolean( } Boolean keep = function.apply(input); - if (TRUE.equals(keep)) { - elementType.appendTo(arrayBlock, position, resultBuilder); - } + positions[length] = position; + length += TRUE.equals(keep) ? 1 : 0; + } + if (positions.length == length) { + return arrayBlock; } - return resultBuilder.build(); + return arrayBlock.copyPositions(positions, 0, length); } @TypeParameter("T") @@ -111,7 +119,8 @@ public static Block filterObject( @SqlType("function(T, boolean)") ObjectToBooleanFunction function) { int positionCount = arrayBlock.getPositionCount(); - BlockBuilder resultBuilder = elementType.createBlockBuilder(null, positionCount); + int[] positions = new int[positionCount]; + int length = 0; for (int position = 0; position < positionCount; position++) { Object input = null; if (!arrayBlock.isNull(position)) { @@ -119,10 +128,12 @@ public static Block filterObject( } Boolean keep = function.apply(input); - if (TRUE.equals(keep)) { - elementType.appendTo(arrayBlock, position, resultBuilder); - } + positions[length] = position; + length += TRUE.equals(keep) ? 1 : 0; + } + if (positions.length == length) { + return arrayBlock; } - return resultBuilder.build(); + return arrayBlock.copyPositions(positions, 0, length); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayFilter.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayFilter.java index 6628df298941..f1ad4b805550 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayFilter.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayFilter.java @@ -28,12 +28,14 @@ import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.Signature; import io.trino.spi.type.ArrayType; +import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.sql.gen.ExpressionCompiler; import io.trino.sql.relational.CallExpression; import io.trino.sql.relational.LambdaDefinitionExpression; import io.trino.sql.relational.RowExpression; +import io.trino.sql.relational.SpecialForm; import io.trino.sql.relational.VariableReferenceExpression; import io.trino.sql.tree.QualifiedName; import io.trino.type.FunctionType; @@ -58,19 +60,24 @@ import static com.google.common.base.Throwables.throwIfUnchecked; import static com.google.common.base.Verify.verify; +import static io.trino.block.BlockAssertions.createRandomBlockForType; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.trino.operator.scalar.BenchmarkArrayFilter.ExactArrayFilterFunction.EXACT_ARRAY_FILTER_FUNCTION; +import static io.trino.operator.scalar.BenchmarkArrayFilter.ExactArrayFilterObjectFunction.EXACT_ARRAY_FILTER_OBJECT_FUNCTION; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.OperatorType.LESS_THAN; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.TypeSignature.arrayType; import static io.trino.spi.type.TypeSignature.functionType; import static io.trino.spi.type.TypeUtils.readNativeValue; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.relational.Expressions.constant; import static io.trino.sql.relational.Expressions.field; +import static io.trino.sql.relational.SpecialForm.Form.DEREFERENCE; import static io.trino.testing.TestingConnectorSession.SESSION; import static io.trino.util.Reflection.methodHandle; import static java.lang.Boolean.TRUE; @@ -88,9 +95,11 @@ public class BenchmarkArrayFilter private static final int ARRAY_SIZE = 4; private static final int NUM_TYPES = 1; private static final List TYPES = ImmutableList.of(BIGINT); + private static final List ROW_TYPES = ImmutableList.of(RowType.anonymous(ImmutableList.of(BIGINT, DOUBLE))); static { verify(NUM_TYPES == TYPES.size()); + verify(NUM_TYPES == ROW_TYPES.size()); } @Benchmark @@ -105,6 +114,18 @@ public List> benchmark(BenchmarkData data) data.getPage())); } + @Benchmark + @OperationsPerInvocation(POSITIONS * ARRAY_SIZE * NUM_TYPES) + public List> benchmarkObject(RowBenchmarkData data) + { + return ImmutableList.copyOf( + data.getPageProcessor().process( + SESSION, + new DriverYieldSignal(), + newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), + data.getPage())); + } + @SuppressWarnings("FieldMayBeFinal") @State(Scope.Thread) public static class BenchmarkData @@ -172,15 +193,85 @@ public Page getPage() } } + @SuppressWarnings("FieldMayBeFinal") + @State(Scope.Thread) + public static class RowBenchmarkData + { + @Param({"filter", "exact_filter"}) + private String name = "filter"; + + private Page page; + private PageProcessor pageProcessor; + + @Setup + public void setup() + { + TestingFunctionResolution functionResolution = new TestingFunctionResolution(InternalFunctionBundle.builder().function(EXACT_ARRAY_FILTER_OBJECT_FUNCTION).build()); + ExpressionCompiler compiler = functionResolution.getExpressionCompiler(); + ImmutableList.Builder projectionsBuilder = ImmutableList.builder(); + Block[] blocks = new Block[ROW_TYPES.size()]; + for (int i = 0; i < ROW_TYPES.size(); i++) { + Type elementType = ROW_TYPES.get(i); + ArrayType arrayType = new ArrayType(elementType); + ResolvedFunction resolvedFunction = functionResolution.resolveFunction( + QualifiedName.of(name), + fromTypes(arrayType, new FunctionType(ROW_TYPES, BOOLEAN))); + ResolvedFunction lessThan = functionResolution.resolveOperator(LESS_THAN, ImmutableList.of(BIGINT, BIGINT)); + + projectionsBuilder.add(new CallExpression(resolvedFunction, ImmutableList.of( + field(0, arrayType), + new LambdaDefinitionExpression( + ImmutableList.of(elementType), + ImmutableList.of("x"), + new CallExpression( + lessThan, + ImmutableList.of( + constant(0L, BIGINT), + new SpecialForm( + DEREFERENCE, + BIGINT, + new VariableReferenceExpression("x", elementType), + constant(0, INTEGER)))))))); + blocks[i] = createChannel(POSITIONS, arrayType); + } + + ImmutableList projections = projectionsBuilder.build(); + pageProcessor = compiler.compilePageProcessor(Optional.empty(), projections).get(); + page = new Page(blocks); + } + + private static Block createChannel(int positionCount, ArrayType arrayType) + { + return createRandomBlockForType(arrayType, positionCount, 0.2F); + } + + public PageProcessor getPageProcessor() + { + return pageProcessor; + } + + public Page getPage() + { + return page; + } + } + public static void main(String[] args) throws Exception { // assure the benchmarks are valid before running BenchmarkData data = new BenchmarkData(); data.setup(); - new BenchmarkArrayFilter().benchmark(data); + BenchmarkArrayFilter benchmarkArrayFilter = new BenchmarkArrayFilter(); + benchmarkArrayFilter.benchmark(data); - Benchmarks.benchmark(BenchmarkArrayFilter.class).run(); + RowBenchmarkData rowData = new RowBenchmarkData(); + rowData.setup(); + benchmarkArrayFilter.benchmarkObject(rowData); + + Benchmarks.benchmark(BenchmarkArrayFilter.class) + .withOptions(optionsBuilder -> optionsBuilder.jvmArgs("-Xmx4g")) + .run(); } public static final class ExactArrayFilterFunction @@ -237,4 +328,59 @@ public static Block filter(Type type, Block block, MethodHandle function) return resultBuilder.build(); } } + + public static final class ExactArrayFilterObjectFunction + extends SqlScalarFunction + { + public static final ExactArrayFilterObjectFunction EXACT_ARRAY_FILTER_OBJECT_FUNCTION = new ExactArrayFilterObjectFunction(); + + private static final MethodHandle METHOD_HANDLE = methodHandle(ExactArrayFilterObjectFunction.class, "filterObject", Type.class, Block.class, MethodHandle.class); + + private ExactArrayFilterObjectFunction() + { + super(FunctionMetadata.scalarBuilder() + .signature(Signature.builder() + .name("exact_filter") + .typeVariable("T") + .returnType(arrayType(new TypeSignature("T"))) + .argumentType(arrayType(new TypeSignature("T"))) + .argumentType(functionType(new TypeSignature("T"), BOOLEAN.getTypeSignature())) + .build()) + .nondeterministic() + .description("return array containing elements that match the given predicate") + .build()); + } + + @Override + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) + { + Type type = ((ArrayType) boundSignature.getReturnType()).getElementType(); + return new ChoicesSpecializedSqlScalarFunction( + boundSignature, + FAIL_ON_NULL, + ImmutableList.of(NEVER_NULL, NEVER_NULL), + METHOD_HANDLE.bindTo(type)); + } + + public static Block filterObject(Type type, Block block, MethodHandle function) + { + int positionCount = block.getPositionCount(); + BlockBuilder resultBuilder = type.createBlockBuilder(null, positionCount); + for (int position = 0; position < positionCount; position++) { + Object input = type.getObject(block, position); + Boolean keep; + try { + keep = (Boolean) function.invokeExact(input); + } + catch (Throwable t) { + throwIfUnchecked(t); + throw new RuntimeException(t); + } + if (TRUE.equals(keep)) { + type.appendTo(block, position, resultBuilder); + } + } + return resultBuilder.build(); + } + } }