diff --git a/core/trino-main/src/main/java/io/trino/execution/CallTask.java b/core/trino-main/src/main/java/io/trino/execution/CallTask.java index 5aa5aa9f99fa..36f3a5ea68ce 100644 --- a/core/trino-main/src/main/java/io/trino/execution/CallTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/CallTask.java @@ -127,7 +127,7 @@ public ListenableFuture execute( for (int i = 0; i < call.getArguments().size(); i++) { CallArgument argument = call.getArguments().get(i); if (argument.getName().isPresent()) { - String name = argument.getName().get(); + String name = argument.getName().get().getCanonicalValue(); if (names.put(name, argument) != null) { throw semanticException(INVALID_ARGUMENTS, argument, "Duplicate procedure argument: %s", name); } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 4c4b7cfa0209..1a2937755c5d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -1127,7 +1127,7 @@ private Map processTableExecuteArguments(TableExecute node, if (anyNamed) { // all properties named for (CallArgument argument : arguments) { - if (argumentsMap.put(argument.getName().get(), argument.getValue()) != null) { + if (argumentsMap.put(argument.getName().get().getCanonicalValue(), argument.getValue()) != null) { throw semanticException(DUPLICATE_PROPERTY, argument, "Duplicate named argument: %s", argument.getName()); } } diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 1c9aba0b6821..ccfb20151d23 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -2698,7 +2698,7 @@ public Node visitPositionalArgument(SqlBaseParser.PositionalArgumentContext cont @Override public Node visitNamedArgument(SqlBaseParser.NamedArgumentContext context) { - return new CallArgument(getLocation(context), context.identifier().getText(), (Expression) visit(context.expression())); + return new CallArgument(getLocation(context), (Identifier) visit(context.identifier()), (Expression) visit(context.expression())); } @Override diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/CallArgument.java b/core/trino-parser/src/main/java/io/trino/sql/tree/CallArgument.java index 3fed6725216a..85467ad8e8e6 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/CallArgument.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/CallArgument.java @@ -25,7 +25,7 @@ public final class CallArgument extends Node { - private final Optional name; + private final Optional name; private final Expression value; public CallArgument(Expression value) @@ -38,24 +38,24 @@ public CallArgument(NodeLocation location, Expression value) this(Optional.of(location), Optional.empty(), value); } - public CallArgument(String name, Expression value) + public CallArgument(Identifier name, Expression value) { this(Optional.empty(), Optional.of(name), value); } - public CallArgument(NodeLocation location, String name, Expression value) + public CallArgument(NodeLocation location, Identifier name, Expression value) { this(Optional.of(location), Optional.of(name), value); } - public CallArgument(Optional location, Optional name, Expression value) + public CallArgument(Optional location, Optional name, Expression value) { super(location); this.name = requireNonNull(name, "name is null"); this.value = requireNonNull(value, "value is null"); } - public Optional getName() + public Optional getName() { return name; } diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index 945bd69b53ee..fc449210b4dd 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -1909,8 +1909,8 @@ public void testTableExecute() table, procedure, ImmutableList.of( - new CallArgument("bah", new LongLiteral("1")), - new CallArgument("wuh", new StringLiteral("clap"))), + new CallArgument(identifier("bah"), new LongLiteral("1")), + new CallArgument(identifier("wuh"), new StringLiteral("clap"))), Optional.of( new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, new Identifier("age"), @@ -2577,8 +2577,8 @@ public void testCall() assertStatement("CALL foo()", new Call(QualifiedName.of("foo"), ImmutableList.of())); assertStatement("CALL foo(123, a => 1, b => 'go', 456)", new Call(QualifiedName.of("foo"), ImmutableList.of( new CallArgument(new LongLiteral("123")), - new CallArgument("a", new LongLiteral("1")), - new CallArgument("b", new StringLiteral("go")), + new CallArgument(identifier("a"), new LongLiteral("1")), + new CallArgument(identifier("b"), new StringLiteral("go")), new CallArgument(new LongLiteral("456"))))); } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/TestingProcedures.java b/testing/trino-testing/src/main/java/io/trino/testing/TestingProcedures.java index bf37ee3442ad..45d0ee4e1bba 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/TestingProcedures.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/TestingProcedures.java @@ -127,6 +127,12 @@ public void error() throw new RuntimeException("test error from procedure"); } + @UsedByGeneratedCode + public void names(ConnectorSession session, String x, String y, String z, String v) + { + tester.recordCalled("names", x, y, z, v); + } + public List getProcedures(String schema) { return ImmutableList.builder() @@ -164,6 +170,11 @@ public List getProcedures(String schema) new Argument("v", VARCHAR, false, "v default")))) .add(procedure(schema, "test_exception", "exception", ImmutableList.of())) .add(procedure(schema, "test_error", "error", ImmutableList.of())) + .add(procedure(schema, "test_argument_names", "names", ImmutableList.of( + new Argument("lower", VARCHAR, false, "a"), + new Argument("UPPER", VARCHAR, false, "b"), + new Argument("MixeD", VARCHAR, false, "c"), + new Argument("with space", VARCHAR, false, "d")))) .build(); } diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestProcedureCall.java b/testing/trino-tests/src/test/java/io/trino/tests/TestProcedureCall.java index 1f40918b21e5..d1c15fb1a494 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/TestProcedureCall.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestProcedureCall.java @@ -92,8 +92,8 @@ public void testProcedureCall() assertCall("CALL test_args(123, 4.5, 'hello', true)", "args", 123L, 4.5, "hello", true); assertCall("CALL test_args(-5, nan(), 'bye', false)", "args", -5L, Double.NaN, "bye", false); assertCall("CALL test_args(3, 88, 'coerce', true)", "args", 3L, 88.0, "coerce", true); - assertCall("CALL test_args(x => 123, y => 4.5, z => 'hello', q => true)", "args", 123L, 4.5, "hello", true); - assertCall("CALL test_args(q => true, z => 'hello', y => 4.5, x => 123)", "args", 123L, 4.5, "hello", true); + assertCall("CALL test_args(\"x\" => 123, \"y\" => 4.5, \"z\" => 'hello', \"q\" => true)", "args", 123L, 4.5, "hello", true); + assertCall("CALL test_args(\"q\" => true, \"z\" => 'hello', \"y\" => 4.5, \"x\" => 123)", "args", 123L, 4.5, "hello", true); assertCall("CALL test_nulls(123, null)", "nulls", 123L, null); assertCall("CALL test_nulls(null, 'apple')", "nulls", null, "apple"); @@ -117,10 +117,10 @@ public void testProcedureCall() assertCallFails("CALL test_simple(123)", "line 1:1: Too many arguments for procedure"); assertCallFails("CALL test_args(123, 4.5, 'hello')", "line 1:1: Required procedure argument 'q' is missing"); - assertCallFails("CALL test_args(x => 123, y => 4.5, q => true)", "line 1:1: Required procedure argument 'z' is missing"); - assertCallFails("CALL test_args(123, 4.5, 'hello', q => true)", "line 1:1: Named and positional arguments cannot be mixed"); - assertCallFails("CALL test_args(x => 3, x => 4)", "line 1:24: Duplicate procedure argument: x"); - assertCallFails("CALL test_args(t => 404)", "line 1:16: Unknown argument name: t"); + assertCallFails("CALL test_args(\"x\" => 123, \"y\" => 4.5, \"q\" => true)", "line 1:1: Required procedure argument 'z' is missing"); + assertCallFails("CALL test_args(123, 4.5, 'hello', \"q\" => true)", "line 1:1: Named and positional arguments cannot be mixed"); + assertCallFails("CALL test_args(\"x\" => 3, \"x\" => 4)", "line 1:26: Duplicate procedure argument: x"); + assertCallFails("CALL test_args(\"t\" => 404)", "line 1:16: Unknown argument name: t"); assertCallFails("CALL test_nulls('hello', null)", "line 1:17: Cannot cast type varchar(5) to bigint"); assertCallFails("CALL test_nulls(null, 123)", "line 1:23: Cannot cast type integer to varchar"); } @@ -128,32 +128,47 @@ public void testProcedureCall() @Test public void testProcedureCallWithOptionals() { - // test_optionals(x => Optional['hello']) - // test_optionals2(x, y => Optional['world]) - // test_optionals3(x => Optional['this'], y => Optional['is'], z => Optional['default']) - // test_optionals4(x, y, z => Optional['z default'], v => Optional['v default']) + // test_optionals(\"x\" => Optional['hello']) + // test_optionals2(x, \"y\" => Optional['world]) + // test_optionals3(\"x\" => Optional['this'], \"y\" => Optional['is'], \"z\" => Optional['default']) + // test_optionals4(x, y, \"z\" => Optional['z default'], \"v\" => Optional['v default']) assertCall("CALL test_optionals()", "optionals", "hello"); - assertCall("CALL test_optionals(x => 'x')", "optionals", "x"); - assertCall("CALL test_optionals2(x => 'ab')", "optionals2", "ab", "world"); + assertCall("CALL test_optionals(\"x\" => 'x')", "optionals", "x"); + assertCall("CALL test_optionals2(\"x\" => 'ab')", "optionals2", "ab", "world"); assertCall("CALL test_optionals2('ab')", "optionals2", "ab", "world"); - assertCall("CALL test_optionals2(x => 'ab', y => 'cd')", "optionals2", "ab", "cd"); - assertCall("CALL test_optionals2(y => 'cd', x => 'ab')", "optionals2", "ab", "cd"); + assertCall("CALL test_optionals2(\"x\" => 'ab', \"y\" => 'cd')", "optionals2", "ab", "cd"); + assertCall("CALL test_optionals2(\"y\" => 'cd', \"x\" => 'ab')", "optionals2", "ab", "cd"); assertCall("CALL test_optionals2('ab', 'cd')", "optionals2", "ab", "cd"); - assertCall("CALL test_optionals3(x => 'ab', z => 'cd')", "optionals3", "ab", "is", "cd"); + assertCall("CALL test_optionals3(\"x\" => 'ab', \"z\" => 'cd')", "optionals3", "ab", "is", "cd"); assertCall("CALL test_optionals3('ab', 'cd', 'ef')", "optionals3", "ab", "cd", "ef"); assertCall("CALL test_optionals3('ab', 'cd')", "optionals3", "ab", "cd", "default"); assertCall("CALL test_optionals3('ab')", "optionals3", "ab", "is", "default"); - assertCall("CALL test_optionals3(y => 'ab', z => 'cd')", "optionals3", "this", "ab", "cd"); - assertCall("CALL test_optionals3(z => 'cd')", "optionals3", "this", "is", "cd"); + assertCall("CALL test_optionals3(\"y\" => 'ab', \"z\" => 'cd')", "optionals3", "this", "ab", "cd"); + assertCall("CALL test_optionals3(\"z\" => 'cd')", "optionals3", "this", "is", "cd"); assertCall("CALL test_optionals4('a', 'b')", "optionals4", "a", "b", "z default", "v default"); - assertCall("CALL test_optionals4(x => 'x val', y => 'y val')", "optionals4", "x val", "y val", "z default", "v default"); - assertCall("CALL test_optionals4(z => 'z val', v => 'v val', x => 'x val', y => 'y val')", "optionals4", "x val", "y val", "z val", "v val"); - assertCall("CALL test_optionals4(v => 'v val', x => 'x val', y => 'y val', z => 'z val')", "optionals4", "x val", "y val", "z val", "v val"); + assertCall("CALL test_optionals4(\"x\" => 'x val', \"y\" => 'y val')", "optionals4", "x val", "y val", "z default", "v default"); + assertCall("CALL test_optionals4(\"z\" => 'z val', \"v\" => 'v val', \"x\" => 'x val', \"y\" => 'y val')", "optionals4", "x val", "y val", "z val", "v val"); + assertCall("CALL test_optionals4(\"v\" => 'v val', \"x\" => 'x val', \"y\" => 'y val', \"z\" => 'z val')", "optionals4", "x val", "y val", "z val", "v val"); assertCallFails("CALL test_optionals2()", "line 1:1: Required procedure argument 'x' is missing"); - assertCallFails("CALL test_optionals4(z => 'cd')", "line 1:1: Required procedure argument 'x' is missing"); - assertCallFails("CALL test_optionals4(z => 'cd', v => 'value')", "line 1:1: Required procedure argument 'x' is missing"); - assertCallFails("CALL test_optionals4(y => 'cd', v => 'value')", "line 1:1: Required procedure argument 'x' is missing"); + assertCallFails("CALL test_optionals4(\"z\" => 'cd')", "line 1:1: Required procedure argument 'x' is missing"); + assertCallFails("CALL test_optionals4(\"z\" => 'cd', \"v\" => 'value')", "line 1:1: Required procedure argument 'x' is missing"); + assertCallFails("CALL test_optionals4(\"y\" => 'cd', \"v\" => 'value')", "line 1:1: Required procedure argument 'x' is missing"); + } + + @Test + public void testNamedArguments() + { + assertCallFails("CALL test_argument_names(lower => 'a')", "line 1:26: Unknown argument name: LOWER"); + assertCallFails("CALL test_argument_names(LOWER => 'a')", "line 1:26: Unknown argument name: LOWER"); + assertCall("CALL test_argument_names(\"lower\" => 'a')", "names", "a", "b", "c", "d"); + assertCall("CALL test_argument_names(upper => 'b')", "names", "a", "b", "c", "d"); + assertCall("CALL test_argument_names(UPPER => 'b')", "names", "a", "b", "c", "d"); + assertCallFails("CALL test_argument_names(\"upper\" => 'b')", "line 1:26: Unknown argument name: upper"); + assertCallFails("CALL test_argument_names(MixeD => 'c')", "line 1:26: Unknown argument name: MIXED"); + assertCallFails("CALL test_argument_names(MIXED => 'c')", "line 1:26: Unknown argument name: MIXED"); + assertCall("CALL test_argument_names(\"MixeD\" => 'c')", "names", "a", "b", "c", "d"); + assertCall("CALL test_argument_names(\"with space\" => 'd')", "names", "a", "b", "c", "d"); } private void assertCall(@Language("SQL") String sql, String name, Object... arguments)