diff --git a/presto-main/src/main/java/com/facebook/presto/execution/CallTask.java b/presto-main/src/main/java/com/facebook/presto/execution/CallTask.java index e24d5075bf52a..308a66dac9d9a 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/CallTask.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/CallTask.java @@ -54,7 +54,9 @@ import static com.facebook.presto.sql.planner.ExpressionInterpreter.evaluateConstantExpression; import static com.facebook.presto.util.Failures.checkCondition; import static com.google.common.base.Throwables.throwIfInstanceOf; +import static com.google.common.base.Verify.verify; import static com.google.common.util.concurrent.Futures.immediateFuture; +import static java.lang.String.format; import static java.util.Arrays.asList; public class CallTask @@ -114,10 +116,14 @@ else if (i < procedure.getArguments().size()) { } } - // verify argument count - if (names.size() < positions.size()) { - throw new SemanticException(INVALID_PROCEDURE_ARGUMENTS, call, "Too few arguments for procedure"); - } + procedure.getArguments().stream() + .filter(Argument::isRequired) + .filter(argument -> !names.containsKey(argument.getName())) + .map(Argument::getName) + .findFirst() + .ifPresent(argument -> { + throw new SemanticException(INVALID_PROCEDURE_ARGUMENTS, call, format("Required procedure argument '%s' is missing", argument)); + }); // get argument values Object[] values = new Object[procedure.getArguments().size()]; @@ -135,6 +141,16 @@ else if (i < procedure.getArguments().size()) { values[index] = toTypeObjectValue(session, type, value); } + // fill values with optional arguments defaults + for (int i = 0; i < procedure.getArguments().size(); i++) { + Argument argument = procedure.getArguments().get(i); + + if (!names.containsKey(argument.getName())) { + verify(argument.isOptional()); + values[i] = toTypeObjectValue(session, metadata.getType(argument.getType()), argument.getDefaultValue()); + } + } + // validate arguments MethodType methodType = procedure.getMethodHandle().type(); for (int i = 0; i < procedure.getArguments().size(); i++) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/procedure/Procedure.java b/presto-spi/src/main/java/com/facebook/presto/spi/procedure/Procedure.java index c84473732329c..658fff68b9314 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/procedure/Procedure.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/procedure/Procedure.java @@ -16,6 +16,8 @@ import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.spi.ConnectorSession; +import javax.annotation.Nullable; + import java.lang.invoke.MethodHandle; import java.util.ArrayList; import java.util.HashSet; @@ -23,6 +25,7 @@ import java.util.Set; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static java.lang.String.format; import static java.util.Collections.unmodifiableList; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; @@ -44,7 +47,13 @@ public Procedure(String schema, String name, List arguments, MethodHan Set names = new HashSet<>(); for (Argument argument : arguments) { - checkArgument(names.add(argument.getName()), "Duplicate argument name: " + argument.getName()); + checkArgument(names.add(argument.getName()), format("Duplicate argument name: '%s'", argument.getName())); + } + + for (int index = 1; index < arguments.size(); index++) { + if (arguments.get(index - 1).isOptional() && arguments.get(index).isRequired()) { + throw new IllegalArgumentException("Optional arguments should follow required ones"); + } } checkArgument(!methodHandle.isVarargsCollector(), "Method must have fixed arity"); @@ -93,16 +102,25 @@ public static class Argument { private final String name; private final TypeSignature type; + private final boolean required; + private final Object defaultValue; public Argument(String name, String type) { - this(name, parseTypeSignature(type)); + this(name, parseTypeSignature(type), true, null); + } + + public Argument(String name, String type, boolean required, @Nullable Object defaultValue) + { + this(name, parseTypeSignature(type), required, defaultValue); } - public Argument(String name, TypeSignature type) + public Argument(String name, TypeSignature type, boolean required, @Nullable Object defaultValue) { this.name = checkNotNullOrEmpty(name, "name"); this.type = requireNonNull(type, "type is null"); + this.required = required; + this.defaultValue = defaultValue; } public String getName() @@ -115,6 +133,24 @@ public TypeSignature getType() return type; } + public boolean isRequired() + { + return required; + } + + public boolean isOptional() + { + return !required; + } + + /** + * Argument default value in type's stack representation. + */ + public Object getDefaultValue() + { + return defaultValue; + } + @Override public String toString() { diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/TestingProcedures.java b/presto-tests/src/main/java/com/facebook/presto/tests/TestingProcedures.java index 157e63abd9914..26561d70a75e3 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/TestingProcedures.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/TestingProcedures.java @@ -96,6 +96,30 @@ public void exception() throw new PrestoException(INVALID_PROCEDURE_ARGUMENT, "test exception from procedure"); } + @UsedByGeneratedCode + public void optionals(ConnectorSession session, String x) + { + tester.recordCalled("optionals", x); + } + + @UsedByGeneratedCode + public void optionals2(ConnectorSession session, String x, String y) + { + tester.recordCalled("optionals2", x, y); + } + + @UsedByGeneratedCode + public void optionals3(ConnectorSession session, String x, String y, String z) + { + tester.recordCalled("optionals3", x, y, z); + } + + @UsedByGeneratedCode + public void optionals4(ConnectorSession session, String x, String y, String z, String v) + { + tester.recordCalled("optionals4", x, y, z, v); + } + @UsedByGeneratedCode public void error() { @@ -124,6 +148,20 @@ public List getProcedures(String schema) new Argument("x", BIGINT)))) .add(procedure(schema, "test_session_last", "sessionLast", ImmutableList.of( new Argument("x", VARCHAR)))) + .add(procedure(schema, "test_optionals", "optionals", ImmutableList.of( + new Argument("x", VARCHAR, false, "hello")))) + .add(procedure(schema, "test_optionals2", "optionals2", ImmutableList.of( + new Argument("x", VARCHAR), + new Argument("y", VARCHAR, false, "world")))) + .add(procedure(schema, "test_optionals3", "optionals3", ImmutableList.of( + new Argument("x", VARCHAR, false, "this"), + new Argument("y", VARCHAR, false, "is"), + new Argument("z", VARCHAR, false, "default")))) + .add(procedure(schema, "test_optionals4", "optionals4", ImmutableList.of( + new Argument("x", VARCHAR), + new Argument("y", VARCHAR), + new Argument("z", VARCHAR, false, "z default"), + new Argument("v", VARCHAR, false, "v default")))) .add(procedure(schema, "test_exception", "exception", ImmutableList.of())) .add(procedure(schema, "test_error", "error", ImmutableList.of())) .build(); diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestProcedureCall.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestProcedureCall.java index ad50ab13ed040..9c5d7dae8b611 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestProcedureCall.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestProcedureCall.java @@ -113,8 +113,8 @@ public void testProcedureCall() assertCallFails("CALL test_args(123, 4.5, 'hello', null)", "Procedure argument cannot be null: q"); assertCallFails("CALL test_simple(123)", "line 1:1: Too many arguments for procedure"); - assertCallFails("CALL test_args(123, 4.5, 'hello')", "line 1:1: Too few arguments for procedure"); - assertCallFails("CALL test_args(x => 123, y => 4.5, q => true)", "line 1:1: Too few arguments for procedure"); + assertCallFails("CALL test_args(123, 4.5, 'hello')", "line 1:1: Required procedure argument 'q' is missing"); + assertCallFails("CALL test_args(x => 123, y => 4.5, q => true)", "line 1:1: Required procedure argument 'z' is missing"); assertCallFails("CALL test_args(123, 4.5, 'hello', q => true)", "line 1:1: Named and positional arguments cannot be mixed"); assertCallFails("CALL test_args(x => 3, x => 4)", "line 1:24: Duplicate procedure argument: x"); assertCallFails("CALL test_args(t => 404)", "line 1:16: Unknown argument name: t"); @@ -122,6 +122,37 @@ public void testProcedureCall() assertCallFails("CALL test_nulls(null, 123)", "line 1:23: Cannot cast type integer to varchar"); } + @Test + public void testProcedureCallWithOptionals() + { + // test_optionals(x => Optional['hello']) + // test_optionals2(x, y => Optional['world]) + // test_optionals3(x => Optional['this'], y => Optional['is'], z => Optional['default']) + // test_optionals4(x, y, z => Optional['z default'], v => Optional['v default']) + assertCall("CALL test_optionals()", "optionals", "hello"); + assertCall("CALL test_optionals(x => 'x')", "optionals", "x"); + assertCall("CALL test_optionals2(x => 'ab')", "optionals2", "ab", "world"); + assertCall("CALL test_optionals2('ab')", "optionals2", "ab", "world"); + assertCall("CALL test_optionals2(x => 'ab', y => 'cd')", "optionals2", "ab", "cd"); + assertCall("CALL test_optionals2(y => 'cd', x => 'ab')", "optionals2", "ab", "cd"); + assertCall("CALL test_optionals2('ab', 'cd')", "optionals2", "ab", "cd"); + assertCall("CALL test_optionals3(x => 'ab', z => 'cd')", "optionals3", "ab", "is", "cd"); + assertCall("CALL test_optionals3('ab', 'cd', 'ef')", "optionals3", "ab", "cd", "ef"); + assertCall("CALL test_optionals3('ab', 'cd')", "optionals3", "ab", "cd", "default"); + assertCall("CALL test_optionals3('ab')", "optionals3", "ab", "is", "default"); + assertCall("CALL test_optionals3(y => 'ab', z => 'cd')", "optionals3", "this", "ab", "cd"); + assertCall("CALL test_optionals3(z => 'cd')", "optionals3", "this", "is", "cd"); + assertCall("CALL test_optionals4('a', 'b')", "optionals4", "a", "b", "z default", "v default"); + assertCall("CALL test_optionals4(x => 'x val', y => 'y val')", "optionals4", "x val", "y val", "z default", "v default"); + assertCall("CALL test_optionals4(z => 'z val', v => 'v val', x => 'x val', y => 'y val')", "optionals4", "x val", "y val", "z val", "v val"); + assertCall("CALL test_optionals4(v => 'v val', x => 'x val', y => 'y val', z => 'z val')", "optionals4", "x val", "y val", "z val", "v val"); + + assertCallFails("CALL test_optionals2()", "line 1:1: Required procedure argument 'x' is missing"); + assertCallFails("CALL test_optionals4(z => 'cd')", "line 1:1: Required procedure argument 'x' is missing"); + assertCallFails("CALL test_optionals4(z => 'cd', v => 'value')", "line 1:1: Required procedure argument 'x' is missing"); + assertCallFails("CALL test_optionals4(y => 'cd', v => 'value')", "line 1:1: Required procedure argument 'x' is missing"); + } + private void assertCall(@Language("SQL") String sql, String name, Object... arguments) { tester.reset(); diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestProcedureCreation.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestProcedureCreation.java new file mode 100644 index 0000000000000..b8954e8e11b4e --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestProcedureCreation.java @@ -0,0 +1,153 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.tests; + +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.procedure.Procedure; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.common.block.MethodHandleUtil.methodHandle; +import static com.facebook.presto.common.type.StandardTypes.VARCHAR; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +@Test(singleThreaded = true) +public class TestProcedureCreation +{ + @Test + public void shouldThrowExceptionWhenOptionalArgumentIsNotLast() + { + assertThatThrownBy(() -> createTestProcedure(ImmutableList.of( + new Procedure.Argument("name", VARCHAR, false, null), + new Procedure.Argument("name2", VARCHAR, true, null)))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Optional arguments should follow required ones"); + + assertThatThrownBy(() -> createTestProcedure(ImmutableList.of( + new Procedure.Argument("name", VARCHAR, true, null), + new Procedure.Argument("name2", VARCHAR, true, null), + new Procedure.Argument("name3", VARCHAR, true, null), + new Procedure.Argument("name4", VARCHAR, false, null), + new Procedure.Argument("name5", VARCHAR, true, null)))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Optional arguments should follow required ones"); + } + + @Test + public void shouldThrowExceptionWhenArgumentNameRepeates() + { + assertThatThrownBy(() -> createTestProcedure(ImmutableList.of( + new Procedure.Argument("name", VARCHAR, false, null), + new Procedure.Argument("name", VARCHAR, true, null)))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Duplicate argument name: 'name'"); + } + + @Test + public void shouldThrowExceptionWhenProcedureIsNonVoid() + { + assertThatThrownBy(() -> new Procedure( + "schema", + "name", + ImmutableList.of(), + methodHandle(Procedures.class, "funWithoutArguments"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Method must return void"); + } + + @Test + public void shouldThrowExceptionWhenMethodHandleIsNull() + { + assertThatThrownBy(() -> new Procedure( + "schema", + "name", + ImmutableList.of(), + null)) + .isInstanceOf(NullPointerException.class) + .hasMessage("methodHandle is null"); + } + + @Test + public void shouldThrowExceptionWhenMethodHandleHasVarargs() + { + assertThatThrownBy(() -> new Procedure( + "schema", + "name", + ImmutableList.of(), + methodHandle(Procedures.class, "funWithVarargs", String[].class))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Method must have fixed arity"); + } + + @Test + public void shouldThrowExceptionWhenArgumentCountDoesntMatch() + { + assertThatThrownBy(() -> new Procedure( + "schema", + "name", + ImmutableList.of( + new Procedure.Argument("name", VARCHAR, true, null), + new Procedure.Argument("name2", VARCHAR, true, null), + new Procedure.Argument("name3", VARCHAR, true, null)), + methodHandle(Procedures.class, "fun1", ConnectorSession.class, Object.class))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Method parameter count must match arguments"); + } + + private static Procedure createTestProcedure(List arguments) + { + int argumentsCount = arguments.size(); + String functionName = "fun" + argumentsCount; + + Class[] clazzes = new Class[argumentsCount + 1]; + clazzes[0] = ConnectorSession.class; + + for (int i = 0; i < argumentsCount; i++) { + clazzes[i + 1] = Object.class; + } + + return new Procedure( + "schema", + "name", + arguments, + methodHandle(Procedures.class, functionName, clazzes)); + } + + public static class Procedures + { + public void fun0(ConnectorSession session) {} + + public void fun1(ConnectorSession session, Object arg1) {} + + public void fun2(ConnectorSession session, Object arg1, Object arg2) {} + + public void fun2(Object arg1, Object arg2) {} + + public void fun3(ConnectorSession session, Object arg1, Object arg2, Object arg3) {} + + public void fun4(ConnectorSession session, Object arg1, Object arg2, Object arg3, Object arg4) {} + + public void fun5(ConnectorSession session, Object arg1, Object arg2, Object arg3, Object arg4, Object arg5) {} + + public String funWithoutArguments() + { + return ""; + } + + public void funWithVarargs(String... arguments) {} + } +}