diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java index 1d9750498460..1ef2b899b3d9 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java @@ -311,7 +311,7 @@ import static io.trino.operator.scalar.ArrayTransformFunction.ARRAY_TRANSFORM_FUNCTION; import static io.trino.operator.scalar.CastFromUnknownOperator.CAST_FROM_UNKNOWN; import static io.trino.operator.scalar.ConcatFunction.VARBINARY_CONCAT; -import static io.trino.operator.scalar.ConcatFunction.VARCHAR_CONCAT; +import static io.trino.operator.scalar.ConcatFunction.VARCHAR_CONCAT_FUNCTIONS; import static io.trino.operator.scalar.ConcatWsFunction.CONCAT_WS; import static io.trino.operator.scalar.ElementToArrayConcatFunction.ELEMENT_TO_ARRAY_CONCAT_FUNCTION; import static io.trino.operator.scalar.FormatFunction.FORMAT_FUNCTION; @@ -601,7 +601,8 @@ public FunctionRegistry( .functions(MAX_AGGREGATION, MIN_AGGREGATION, new MaxNAggregationFunction(blockTypeOperators), new MinNAggregationFunction(blockTypeOperators)) .function(COUNT_COLUMN) .functions(JSON_TO_ROW, JSON_STRING_TO_ROW, ROW_TO_ROW_CAST) - .functions(VARCHAR_CONCAT, VARBINARY_CONCAT) + .functions(VARCHAR_CONCAT_FUNCTIONS) + .function(VARBINARY_CONCAT) .function(CONCAT_WS) .function(DECIMAL_TO_DECIMAL_CAST) .function(castVarcharToRe2JRegexp(featuresConfig.getRe2JDfaStatesLimit(), featuresConfig.getRe2JDfaRetries())) diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatFunction.java index aa31af6a22ee..9cedac029ad4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatFunction.java @@ -25,45 +25,81 @@ import io.trino.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.metadata.FunctionKind.SCALAR; +import static io.trino.metadata.Signature.longVariableExpression; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.type.TypeSignatureParameter.typeVariable; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.util.Reflection.methodHandle; import static java.lang.Math.addExact; +import static java.lang.String.format; import static java.util.Collections.nCopies; public final class ConcatFunction extends SqlScalarFunction { // TODO design new variadic functions binding mechanism that will allow to produce VARCHAR(x) where x < MAX_LENGTH. - public static final ConcatFunction VARCHAR_CONCAT = new ConcatFunction(VARCHAR.getTypeSignature(), "Concatenates given strings"); + public static final ConcatFunction[] VARCHAR_CONCAT_FUNCTIONS; - public static final ConcatFunction VARBINARY_CONCAT = new ConcatFunction(VARBINARY.getTypeSignature(), "concatenates given varbinary values"); + public static final ConcatFunction VARBINARY_CONCAT = new ConcatFunction(); + private static final int MIN_INPUT_VALUES = 1; private static final int MAX_INPUT_VALUES = 254; private static final int MAX_OUTPUT_LENGTH = DEFAULT_MAX_PAGE_SIZE_IN_BYTES; - private ConcatFunction(TypeSignature type, String description) + static { + VARCHAR_CONCAT_FUNCTIONS = new ConcatFunction[MAX_INPUT_VALUES - MIN_INPUT_VALUES + 1]; + + for (int arity = MIN_INPUT_VALUES; arity <= MAX_INPUT_VALUES; arity++) { + VARCHAR_CONCAT_FUNCTIONS[arity - MIN_INPUT_VALUES] = new ConcatFunction(arity); + } + } + + // Concat function for VARCHAR type + private ConcatFunction(int arity) { super(new FunctionMetadata( new Signature( "concat", ImmutableList.of(), + ImmutableList.of(longVariableExpression("L", toExpression(arity))), + new TypeSignature(VARCHAR.getTypeSignature().getBase(), typeVariable("L")), + toArgumentTypes(VARCHAR.getTypeSignature().getBase(), arity), + false), + false, + nCopies(arity, new FunctionArgumentDefinition(false)), + false, + true, + "Concatenates given strings", + SCALAR)); + } + + // Concat function for VARBINARY type + private ConcatFunction() + { + super(new FunctionMetadata( + new Signature( + "concat", ImmutableList.of(), - type, - ImmutableList.of(type), + ImmutableList.of(), + VARBINARY.getTypeSignature(), + ImmutableList.of(VARBINARY.getTypeSignature()), true), false, ImmutableList.of(new FunctionArgumentDefinition(false)), false, true, - description, + "concatenates given varbinary values", SCALAR)); } @@ -111,4 +147,16 @@ public static Slice concat(Slice[] values) return result; } + + private static String toExpression(int arity) + { + return format("min(2147483647, %s)", IntStream.rangeClosed(1, arity).mapToObj(number -> "S" + number).collect(Collectors.joining("+"))); + } + + private static List toArgumentTypes(String base, int arity) + { + return IntStream.rangeClosed(1, arity) + .mapToObj(number -> new TypeSignature(base, typeVariable("S" + number))) + .collect(toImmutableList()); + } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestStringFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestStringFunctions.java index a4f8c2850ecf..187f73b05cd9 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestStringFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestStringFunctions.java @@ -107,21 +107,29 @@ public void testCodepoint() public void testConcat() { assertInvalidFunction("CONCAT('')", "There must be two or more concatenation arguments"); - assertFunction("CONCAT('hello', ' world')", VARCHAR, "hello world"); - assertFunction("CONCAT('', '')", VARCHAR, ""); - assertFunction("CONCAT('what', '')", VARCHAR, "what"); - assertFunction("CONCAT('', 'what')", VARCHAR, "what"); - assertFunction("CONCAT(CONCAT('this', ' is'), ' cool')", VARCHAR, "this is cool"); - assertFunction("CONCAT('this', CONCAT(' is', ' cool'))", VARCHAR, "this is cool"); + assertFunction("CONCAT('hello', ' world')", createVarcharType(11), "hello world"); + assertFunction("CONCAT('', '')", createVarcharType(0), ""); + assertFunction("CONCAT('what', '')", createVarcharType(4), "what"); + assertFunction("CONCAT('', 'what')", createVarcharType(4), "what"); + assertFunction("CONCAT(CONCAT('this', ' is'), ' cool')", createVarcharType(12), "this is cool"); + assertFunction("CONCAT('this', CONCAT(' is', ' cool'))", createVarcharType(12), "this is cool"); + assertFunction("CONCAT(CAST('max' AS VARCHAR(2147483647)), ' length')", VARCHAR, "max length"); + assertFunction("CONCAT('max', CAST(' length' AS VARCHAR(2147483647)))", VARCHAR, "max length"); + + assertFunction("'hello' || ' world'", createVarcharType(11), "hello world"); + assertFunction("'hello' || ' ' || 'world'", createVarcharType(11), "hello world"); + assertFunction("CHAR 'hello' || ' ' || 'world'", createCharType(11), "hello world"); + assertFunction("'hello' || CHAR ' ' || 'world'", createCharType(11), "hello world"); + assertFunction("'hello' || ' ' || CHAR 'world'", createCharType(11), "hello world"); // Test concat for non-ASCII - assertFunction("CONCAT('hello na\u00EFve', ' world')", VARCHAR, "hello na\u00EFve world"); - assertFunction("CONCAT('\uD801\uDC2D', 'end')", VARCHAR, "\uD801\uDC2Dend"); - assertFunction("CONCAT('\uD801\uDC2D', 'end', '\uD801\uDC2D')", VARCHAR, "\uD801\uDC2Dend\uD801\uDC2D"); - assertFunction("CONCAT(CONCAT('\u4FE1\u5FF5', ',\u7231'), ',\u5E0C\u671B')", VARCHAR, "\u4FE1\u5FF5,\u7231,\u5E0C\u671B"); + assertFunction("CONCAT('hello na\u00EFve', ' world')", createVarcharType(17), "hello na\u00EFve world"); + assertFunction("CONCAT('\uD801\uDC2D', 'end')", createVarcharType(4), "\uD801\uDC2Dend"); + assertFunction("CONCAT('\uD801\uDC2D', 'end', '\uD801\uDC2D')", createVarcharType(5), "\uD801\uDC2Dend\uD801\uDC2D"); + assertFunction("CONCAT(CONCAT('\u4FE1\u5FF5', ',\u7231'), ',\u5E0C\u671B')", createVarcharType(7), "\u4FE1\u5FF5,\u7231,\u5E0C\u671B"); // Test argument count limit - assertFunction("CONCAT(" + Joiner.on(", ").join(nCopies(127, "'x'")) + ")", VARCHAR, Joiner.on("").join(nCopies(127, "x"))); + assertFunction("CONCAT(" + Joiner.on(", ").join(nCopies(127, "'x'")) + ")", createVarcharType(127), Joiner.on("").join(nCopies(127, "x"))); assertInvalidFunction( "CONCAT(" + Joiner.on(", ").join(nCopies(128, "'x'")) + ")", TOO_MANY_ARGUMENTS,