diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RepeatFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RepeatFunction.java index eb50b170be2ab..cb236183c73fa 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RepeatFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RepeatFunction.java @@ -14,7 +14,7 @@ package com.facebook.presto.operator.scalar; import com.facebook.presto.common.block.Block; -import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.RunLengthEncodedBlock; import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.function.Description; @@ -22,12 +22,10 @@ import com.facebook.presto.spi.function.SqlNullable; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; -import io.airlift.slice.Slice; import static com.facebook.presto.common.type.UnknownType.UNKNOWN; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.util.Failures.checkCondition; -import static java.lang.Math.toIntExact; @ScalarFunction(value = "repeat", calledOnNullInput = true) @Description("Repeat an element for a given number of times") @@ -43,9 +41,9 @@ public static Block repeat( @SqlNullable @SqlType("unknown") Boolean element, @SqlType(StandardTypes.INTEGER) long count) { + checkValidCount(count); checkCondition(element == null, INVALID_FUNCTION_ARGUMENT, "expect null values"); - BlockBuilder blockBuilder = createBlockBuilder(UNKNOWN, count); - return repeatNullValues(blockBuilder, count); + return RunLengthEncodedBlock.create(UNKNOWN, null, (int) count); } @TypeParameter("T") @@ -55,112 +53,19 @@ public static Block repeat( @SqlNullable @SqlType("T") Object element, @SqlType(StandardTypes.INTEGER) long count) { - BlockBuilder blockBuilder = createBlockBuilder(type, count); + checkValidCount(count); if (element == null) { - return repeatNullValues(blockBuilder, count); + return RunLengthEncodedBlock.create(type, null, (int) count); } - if (count > 0) { - type.writeObject(blockBuilder, element); - checkMaxSize(blockBuilder.getSizeInBytes(), count); - } - for (int i = 1; i < count; i++) { - type.writeObject(blockBuilder, element); - } - return blockBuilder.build(); - } - - @TypeParameter("T") - @SqlType("array(T)") - public static Block repeat( - @TypeParameter("T") Type type, - @SqlNullable @SqlType("T") Long element, - @SqlType(StandardTypes.INTEGER) long count) - { - BlockBuilder blockBuilder = createBlockBuilder(type, count); - if (element == null) { - return repeatNullValues(blockBuilder, count); - } - for (int i = 0; i < count; i++) { - type.writeLong(blockBuilder, element); - } - return blockBuilder.build(); - } - - @TypeParameter("T") - @SqlType("array(T)") - public static Block repeat( - @TypeParameter("T") Type type, - @SqlNullable @SqlType("T") Slice element, - @SqlType(StandardTypes.INTEGER) long count) - { - BlockBuilder blockBuilder = createBlockBuilder(type, count); - if (element == null) { - return repeatNullValues(blockBuilder, count); - } - if (count > 0) { - type.writeSlice(blockBuilder, element); - checkMaxSize(blockBuilder.getSizeInBytes(), count); - } - for (int i = 1; i < count; i++) { - type.writeSlice(blockBuilder, element); - } - return blockBuilder.build(); - } - - @TypeParameter("T") - @SqlType("array(T)") - public static Block repeat( - @TypeParameter("T") Type type, - @SqlNullable @SqlType("T") Boolean element, - @SqlType(StandardTypes.INTEGER) long count) - { - BlockBuilder blockBuilder = createBlockBuilder(type, count); - if (element == null) { - return repeatNullValues(blockBuilder, count); - } - for (int i = 0; i < count; i++) { - type.writeBoolean(blockBuilder, element); - } - return blockBuilder.build(); - } - - @TypeParameter("T") - @SqlType("array(T)") - public static Block repeat( - @TypeParameter("T") Type type, - @SqlNullable @SqlType("T") Double element, - @SqlType(StandardTypes.INTEGER) long count) - { - BlockBuilder blockBuilder = createBlockBuilder(type, count); - if (element == null) { - return repeatNullValues(blockBuilder, count); - } - for (int i = 0; i < count; i++) { - type.writeDouble(blockBuilder, element); - } - return blockBuilder.build(); + Block result = RunLengthEncodedBlock.create(type, element, (int) count); + checkCondition(result.getSizeInBytes() < MAX_SIZE_IN_BYTES, INVALID_FUNCTION_ARGUMENT, + "result of repeat function must not take more than 1000000 bytes"); + return result; } - private static BlockBuilder createBlockBuilder(Type type, long count) + private static void checkValidCount(long count) { checkCondition(count <= MAX_RESULT_ENTRIES, INVALID_FUNCTION_ARGUMENT, "count argument of repeat function must be less than or equal to 10000"); checkCondition(count >= 0, INVALID_FUNCTION_ARGUMENT, "count argument of repeat function must be greater than or equal to 0"); - return type.createBlockBuilder(null, toIntExact(count)); - } - - private static Block repeatNullValues(BlockBuilder blockBuilder, long count) - { - for (int i = 0; i < count; i++) { - blockBuilder.appendNull(); - } - return blockBuilder.build(); - } - - private static void checkMaxSize(long bytes, long count) - { - checkCondition( - bytes <= (MAX_SIZE_IN_BYTES + count) / count, - INVALID_FUNCTION_ARGUMENT, - "result of repeat function must not take more than 1000000 bytes"); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java index af50846d46cb4..d25bde5ba8f73 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java @@ -19,6 +19,7 @@ import com.facebook.presto.common.type.Decimals; import com.facebook.presto.common.type.SqlDecimal; import com.facebook.presto.common.type.Type; +import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.FunctionListBuilder; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.SqlScalarFunction; @@ -91,6 +92,11 @@ public final void destroyTestFunctions() functionAssertions = null; } + public FunctionAndTypeManager getFunctionAndTypeManager() + { + return functionAssertions.getFunctionAndTypeManager(); + } + protected void assertFunction(String projection, Type expectedType, Object expected) { functionAssertions.assertFunction(projection, expectedType, expected); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkRepeatFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkRepeatFunction.java new file mode 100644 index 0000000000000..f18b1b3eb2a7b --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkRepeatFunction.java @@ -0,0 +1,142 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.operator.DriverYieldSignal; +import com.facebook.presto.operator.project.PageProcessor; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.sql.gen.ExpressionCompiler; +import com.facebook.presto.sql.gen.PageFunctionCompiler; +import com.google.common.collect.ImmutableList; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OperationsPerInvocation; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; +import org.openjdk.jmh.runner.options.WarmupMode; + +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.testing.TestingConnectorSession.SESSION; + +@SuppressWarnings("MethodMayBeStatic") +@State(Scope.Thread) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@Fork(2) +@Warmup(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@BenchmarkMode(Mode.AverageTime) +public class BenchmarkRepeatFunction +{ + private static final int POSITIONS = 1000; + + public static void main(String[] args) + throws Throwable + { + // assure the benchmarks are valid before running + BenchmarkData data = new BenchmarkData(); + data.setup(); + new BenchmarkRepeatFunction().benchmark(data); + + Options options = new OptionsBuilder() + .verbosity(VerboseMode.NORMAL) + .warmupMode(WarmupMode.INDI) + .include(".*" + BenchmarkRepeatFunction.class.getSimpleName() + ".*") + .build(); + new Runner(options).run(); + } + + @Benchmark + @OperationsPerInvocation(POSITIONS) + public List> benchmark(BenchmarkData data) + { + return ImmutableList.copyOf( + data.getPageProcessor().process( + SESSION.getSqlFunctionProperties(), + new DriverYieldSignal(), + newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), + data.getPage())); + } + + @SuppressWarnings("FieldMayBeFinal") + @State(Scope.Thread) + public static class BenchmarkData + { + private String name = "repeat"; + + @Param({"10", "100", "1000"}) + private int repeatArgument = 100; + + private Page page; + private PageProcessor pageProcessor; + + @Setup + public void setup() + { + MetadataManager metadata = createTestMetadataManager(); + ExpressionCompiler compiler = new ExpressionCompiler(metadata, new PageFunctionCompiler(metadata, 0)); + + Type inputType = INTEGER; + ConstantExpression inputValue = new ConstantExpression((long) 2, inputType); + ConstantExpression repeatNum = new ConstantExpression((long) repeatArgument, INTEGER); + + ImmutableList.Builder projectionsBuilder = ImmutableList.builder(); + + FunctionHandle functionHandle = metadata.getFunctionAndTypeManager().lookupFunction(name, fromTypes(inputType, INTEGER)); + projectionsBuilder.add(new CallExpression( + name, + functionHandle, + new ArrayType(inputType), + ImmutableList.of(inputValue, repeatNum))); + + ImmutableList projections = projectionsBuilder.build(); + pageProcessor = compiler.compilePageProcessor(SESSION.getSqlFunctionProperties(), Optional.empty(), projections).get(); + page = new Page(10000); + } + + public PageProcessor getPageProcessor() + { + return pageProcessor; + } + + public Page getPage() + { + return page; + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java b/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java index a2f33f97324dd..e34e8c179792c 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java @@ -80,6 +80,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.AMBIGUOUS_FUNCTION_CALL; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.FUNCTION_NOT_FOUND; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TYPE_MISMATCH; +import static com.facebook.presto.sql.planner.PlannerUtils.createMapType; import static com.facebook.presto.testing.DateTimeTestingUtils.sqlTimestampOf; import static com.facebook.presto.util.StructuralTestUtil.appendToBlockBuilder; import static com.facebook.presto.util.StructuralTestUtil.arrayBlockOf; @@ -92,6 +93,7 @@ import static java.lang.String.format; import static java.util.Arrays.asList; import static java.util.Collections.emptyList; +import static java.util.Collections.nCopies; import static java.util.Collections.singletonList; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -1685,6 +1687,16 @@ public void testRepeat() assertFunction("REPEAT(true, 1)", new ArrayType(BOOLEAN), ImmutableList.of(true)); assertFunction("REPEAT(0.5E0, 4)", new ArrayType(DOUBLE), ImmutableList.of(0.5, 0.5, 0.5, 0.5)); assertFunction("REPEAT(array[1], 4)", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(1), ImmutableList.of(1), ImmutableList.of(1), ImmutableList.of(1))); + assertFunction("repeat(cast(1 as integer), 10)", new ArrayType(INTEGER), nCopies(10, 1)); + assertFunction("repeat(cast(1 as integer), 0)", new ArrayType(INTEGER), nCopies(0, 1)); + assertFunction("repeat(cast(1 as bigint), 10)", new ArrayType(BIGINT), nCopies(10, (long) 1)); + assertFunction("repeat(cast('ab' as varchar), 10)", new ArrayType(VARCHAR), nCopies(10, "ab")); + assertFunction("repeat(array[cast(2 as bigint)], 10)", new ArrayType(new ArrayType(BIGINT)), nCopies(10, ImmutableList.of((long) 2))); + assertFunction("repeat(array[cast(2 as bigint), 3], 10)", new ArrayType(new ArrayType(BIGINT)), nCopies(10, ImmutableList.of((long) 2, (long) 3))); + assertFunction("repeat(array[cast(2 as integer)], 10)", new ArrayType(new ArrayType(INTEGER)), nCopies(10, ImmutableList.of(2))); + assertFunction("repeat(map(array[cast(2 as integer)], array[cast('ab' as varchar)]), 10)", new ArrayType(createMapType(getFunctionAndTypeManager(), INTEGER, VARCHAR)), nCopies(10, ImmutableMap.of(2, "ab"))); + assertFunction("REPEAT('loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooongvarchar', 9999)", new ArrayType(createVarcharType(108)), nCopies(9999, "loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooongvarchar")); + assertFunction("REPEAT(array[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], 9999)", new ArrayType(new ArrayType(INTEGER)), nCopies(9999, ImmutableList.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20))); // null values assertFunction("REPEAT(null, 4)", new ArrayType(UNKNOWN), asList(null, null, null, null)); @@ -1693,6 +1705,7 @@ public void testRepeat() assertFunction("REPEAT(cast(null as varchar), 4)", new ArrayType(VARCHAR), asList(null, null, null, null)); assertFunction("REPEAT(cast(null as boolean), 4)", new ArrayType(BOOLEAN), asList(null, null, null, null)); assertFunction("REPEAT(cast(null as array(boolean)), 4)", new ArrayType(new ArrayType(BOOLEAN)), asList(null, null, null, null)); + assertFunction("repeat(cast(null as varchar), 10)", new ArrayType(VARCHAR), nCopies(10, null)); // 0 counts assertFunction("REPEAT(cast(null as bigint), 0)", new ArrayType(BIGINT), ImmutableList.of()); @@ -1705,8 +1718,6 @@ public void testRepeat() // illegal inputs assertInvalidFunction("REPEAT(2, -1)", INVALID_FUNCTION_ARGUMENT); assertInvalidFunction("REPEAT(1, 1000000)", INVALID_FUNCTION_ARGUMENT); - assertInvalidFunction("REPEAT('loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooongvarchar', 9999)", INVALID_FUNCTION_ARGUMENT); - assertInvalidFunction("REPEAT(array[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], 9999)", INVALID_FUNCTION_ARGUMENT); } @Test diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index 5e285b3bb424d..f68617124f9bd 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -7463,4 +7463,11 @@ public void testMergeKHyperLogLog() "select cardinality(khyperloglog_agg(v1, v2)), uniqueness_distribution(khyperloglog_agg(v1, v2)) from (values (1, 1, 2, 3), (1, 1, 4, 0), (1, 2, 90, 20), (1, 2, 87, 1), (2, 1, 11, 30), (2, 1, 11, 11), " + "(2, 2, 9, 1), (2, 2, 87, 2)) t(k1, k2, v1, v2)"); } + + @Test + public void testRepeat() + { + assertQuery("select repeat(k1, k2), repeat(k1, 5), repeat(3, k2) from (values (3, 2), (5, 4), (2, 4))t(k1, k2)", + "values (array[3, 3], array[3,3,3,3,3], array[3, 3]), (array[5, 5, 5, 5], array[5, 5, 5, 5, 5], array[3, 3, 3, 3]), (array[2, 2, 2, 2], array[2, 2, 2, 2, 2], array[3, 3, 3, 3])"); + } }