Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@
import static com.facebook.presto.sql.planner.ExpressionInterpreter.evaluateConstantExpression;
import static com.facebook.presto.util.Failures.checkCondition;
import static com.google.common.base.Throwables.throwIfInstanceOf;
import static com.google.common.base.Verify.verify;
import static com.google.common.util.concurrent.Futures.immediateFuture;
import static java.lang.String.format;
import static java.util.Arrays.asList;

public class CallTask
Expand Down Expand Up @@ -114,10 +116,14 @@ else if (i < procedure.getArguments().size()) {
}
}

// verify argument count
if (names.size() < positions.size()) {
throw new SemanticException(INVALID_PROCEDURE_ARGUMENTS, call, "Too few arguments for procedure");
}
procedure.getArguments().stream()
.filter(Argument::isRequired)
.filter(argument -> !names.containsKey(argument.getName()))
.map(Argument::getName)
.findFirst()
.ifPresent(argument -> {
throw new SemanticException(INVALID_PROCEDURE_ARGUMENTS, call, format("Required procedure argument '%s' is missing", argument));
});

// get argument values
Object[] values = new Object[procedure.getArguments().size()];
Expand All @@ -135,6 +141,16 @@ else if (i < procedure.getArguments().size()) {
values[index] = toTypeObjectValue(session, type, value);
}

// fill values with optional arguments defaults
for (int i = 0; i < procedure.getArguments().size(); i++) {
Argument argument = procedure.getArguments().get(i);

if (!names.containsKey(argument.getName())) {
verify(argument.isOptional());
values[i] = toTypeObjectValue(session, metadata.getType(argument.getType()), argument.getDefaultValue());
}
}

// validate arguments
MethodType methodType = procedure.getMethodHandle().type();
for (int i = 0; i < procedure.getArguments().size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.spi.ConnectorSession;

import javax.annotation.Nullable;

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static java.lang.String.format;
import static java.util.Collections.unmodifiableList;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
Expand All @@ -44,7 +47,13 @@ public Procedure(String schema, String name, List<Argument> arguments, MethodHan

Set<String> names = new HashSet<>();
for (Argument argument : arguments) {
checkArgument(names.add(argument.getName()), "Duplicate argument name: " + argument.getName());
checkArgument(names.add(argument.getName()), format("Duplicate argument name: '%s'", argument.getName()));
}

for (int index = 1; index < arguments.size(); index++) {
if (arguments.get(index - 1).isOptional() && arguments.get(index).isRequired()) {
throw new IllegalArgumentException("Optional arguments should follow required ones");
}
}

checkArgument(!methodHandle.isVarargsCollector(), "Method must have fixed arity");
Expand Down Expand Up @@ -93,16 +102,25 @@ public static class Argument
{
private final String name;
private final TypeSignature type;
private final boolean required;
private final Object defaultValue;

public Argument(String name, String type)
{
this(name, parseTypeSignature(type));
this(name, parseTypeSignature(type), true, null);
}

public Argument(String name, String type, boolean required, @Nullable Object defaultValue)
{
this(name, parseTypeSignature(type), required, defaultValue);
}

public Argument(String name, TypeSignature type)
public Argument(String name, TypeSignature type, boolean required, @Nullable Object defaultValue)
{
this.name = checkNotNullOrEmpty(name, "name");
this.type = requireNonNull(type, "type is null");
this.required = required;
this.defaultValue = defaultValue;
}

public String getName()
Expand All @@ -115,6 +133,24 @@ public TypeSignature getType()
return type;
}

public boolean isRequired()
{
return required;
}

public boolean isOptional()
{
return !required;
}

/**
* Argument default value in type's stack representation.
*/
public Object getDefaultValue()
{
return defaultValue;
}

@Override
public String toString()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,30 @@ public void exception()
throw new PrestoException(INVALID_PROCEDURE_ARGUMENT, "test exception from procedure");
}

@UsedByGeneratedCode
public void optionals(ConnectorSession session, String x)
{
tester.recordCalled("optionals", x);
}

@UsedByGeneratedCode
public void optionals2(ConnectorSession session, String x, String y)
{
tester.recordCalled("optionals2", x, y);
}

@UsedByGeneratedCode
public void optionals3(ConnectorSession session, String x, String y, String z)
{
tester.recordCalled("optionals3", x, y, z);
}

@UsedByGeneratedCode
public void optionals4(ConnectorSession session, String x, String y, String z, String v)
{
tester.recordCalled("optionals4", x, y, z, v);
}

@UsedByGeneratedCode
public void error()
{
Expand Down Expand Up @@ -124,6 +148,20 @@ public List<Procedure> getProcedures(String schema)
new Argument("x", BIGINT))))
.add(procedure(schema, "test_session_last", "sessionLast", ImmutableList.of(
new Argument("x", VARCHAR))))
.add(procedure(schema, "test_optionals", "optionals", ImmutableList.of(
new Argument("x", VARCHAR, false, "hello"))))
.add(procedure(schema, "test_optionals2", "optionals2", ImmutableList.of(
new Argument("x", VARCHAR),
new Argument("y", VARCHAR, false, "world"))))
.add(procedure(schema, "test_optionals3", "optionals3", ImmutableList.of(
new Argument("x", VARCHAR, false, "this"),
new Argument("y", VARCHAR, false, "is"),
new Argument("z", VARCHAR, false, "default"))))
.add(procedure(schema, "test_optionals4", "optionals4", ImmutableList.of(
new Argument("x", VARCHAR),
new Argument("y", VARCHAR),
new Argument("z", VARCHAR, false, "z default"),
new Argument("v", VARCHAR, false, "v default"))))
.add(procedure(schema, "test_exception", "exception", ImmutableList.of()))
.add(procedure(schema, "test_error", "error", ImmutableList.of()))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,46 @@ public void testProcedureCall()
assertCallFails("CALL test_args(123, 4.5, 'hello', null)", "Procedure argument cannot be null: q");

assertCallFails("CALL test_simple(123)", "line 1:1: Too many arguments for procedure");
assertCallFails("CALL test_args(123, 4.5, 'hello')", "line 1:1: Too few arguments for procedure");
assertCallFails("CALL test_args(x => 123, y => 4.5, q => true)", "line 1:1: Too few arguments for procedure");
assertCallFails("CALL test_args(123, 4.5, 'hello')", "line 1:1: Required procedure argument 'q' is missing");
assertCallFails("CALL test_args(x => 123, y => 4.5, q => true)", "line 1:1: Required procedure argument 'z' is missing");
assertCallFails("CALL test_args(123, 4.5, 'hello', q => true)", "line 1:1: Named and positional arguments cannot be mixed");
assertCallFails("CALL test_args(x => 3, x => 4)", "line 1:24: Duplicate procedure argument: x");
assertCallFails("CALL test_args(t => 404)", "line 1:16: Unknown argument name: t");
assertCallFails("CALL test_nulls('hello', null)", "line 1:17: Cannot cast type varchar(5) to bigint");
assertCallFails("CALL test_nulls(null, 123)", "line 1:23: Cannot cast type integer to varchar");
}

@Test
public void testProcedureCallWithOptionals()
{
// test_optionals(x => Optional['hello'])
// test_optionals2(x, y => Optional['world])
// test_optionals3(x => Optional['this'], y => Optional['is'], z => Optional['default'])
// test_optionals4(x, y, z => Optional['z default'], v => Optional['v default'])
assertCall("CALL test_optionals()", "optionals", "hello");
assertCall("CALL test_optionals(x => 'x')", "optionals", "x");
assertCall("CALL test_optionals2(x => 'ab')", "optionals2", "ab", "world");
assertCall("CALL test_optionals2('ab')", "optionals2", "ab", "world");
assertCall("CALL test_optionals2(x => 'ab', y => 'cd')", "optionals2", "ab", "cd");
assertCall("CALL test_optionals2(y => 'cd', x => 'ab')", "optionals2", "ab", "cd");
assertCall("CALL test_optionals2('ab', 'cd')", "optionals2", "ab", "cd");
assertCall("CALL test_optionals3(x => 'ab', z => 'cd')", "optionals3", "ab", "is", "cd");
assertCall("CALL test_optionals3('ab', 'cd', 'ef')", "optionals3", "ab", "cd", "ef");
assertCall("CALL test_optionals3('ab', 'cd')", "optionals3", "ab", "cd", "default");
assertCall("CALL test_optionals3('ab')", "optionals3", "ab", "is", "default");
assertCall("CALL test_optionals3(y => 'ab', z => 'cd')", "optionals3", "this", "ab", "cd");
assertCall("CALL test_optionals3(z => 'cd')", "optionals3", "this", "is", "cd");
assertCall("CALL test_optionals4('a', 'b')", "optionals4", "a", "b", "z default", "v default");
assertCall("CALL test_optionals4(x => 'x val', y => 'y val')", "optionals4", "x val", "y val", "z default", "v default");
assertCall("CALL test_optionals4(z => 'z val', v => 'v val', x => 'x val', y => 'y val')", "optionals4", "x val", "y val", "z val", "v val");
assertCall("CALL test_optionals4(v => 'v val', x => 'x val', y => 'y val', z => 'z val')", "optionals4", "x val", "y val", "z val", "v val");

assertCallFails("CALL test_optionals2()", "line 1:1: Required procedure argument 'x' is missing");
assertCallFails("CALL test_optionals4(z => 'cd')", "line 1:1: Required procedure argument 'x' is missing");
assertCallFails("CALL test_optionals4(z => 'cd', v => 'value')", "line 1:1: Required procedure argument 'x' is missing");
assertCallFails("CALL test_optionals4(y => 'cd', v => 'value')", "line 1:1: Required procedure argument 'x' is missing");
}

private void assertCall(@Language("SQL") String sql, String name, Object... arguments)
{
tester.reset();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.presto.tests;

import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.procedure.Procedure;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;

import java.util.List;

import static com.facebook.presto.common.block.MethodHandleUtil.methodHandle;
import static com.facebook.presto.common.type.StandardTypes.VARCHAR;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

@Test(singleThreaded = true)
public class TestProcedureCreation
{
@Test
public void shouldThrowExceptionWhenOptionalArgumentIsNotLast()
{
assertThatThrownBy(() -> createTestProcedure(ImmutableList.of(
new Procedure.Argument("name", VARCHAR, false, null),
new Procedure.Argument("name2", VARCHAR, true, null))))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Optional arguments should follow required ones");

assertThatThrownBy(() -> createTestProcedure(ImmutableList.of(
new Procedure.Argument("name", VARCHAR, true, null),
new Procedure.Argument("name2", VARCHAR, true, null),
new Procedure.Argument("name3", VARCHAR, true, null),
new Procedure.Argument("name4", VARCHAR, false, null),
new Procedure.Argument("name5", VARCHAR, true, null))))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Optional arguments should follow required ones");
}

@Test
public void shouldThrowExceptionWhenArgumentNameRepeates()
{
assertThatThrownBy(() -> createTestProcedure(ImmutableList.of(
new Procedure.Argument("name", VARCHAR, false, null),
new Procedure.Argument("name", VARCHAR, true, null))))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Duplicate argument name: 'name'");
}

@Test
public void shouldThrowExceptionWhenProcedureIsNonVoid()
{
assertThatThrownBy(() -> new Procedure(
"schema",
"name",
ImmutableList.of(),
methodHandle(Procedures.class, "funWithoutArguments")))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Method must return void");
}

@Test
public void shouldThrowExceptionWhenMethodHandleIsNull()
{
assertThatThrownBy(() -> new Procedure(
"schema",
"name",
ImmutableList.of(),
null))
.isInstanceOf(NullPointerException.class)
.hasMessage("methodHandle is null");
}

@Test
public void shouldThrowExceptionWhenMethodHandleHasVarargs()
{
assertThatThrownBy(() -> new Procedure(
"schema",
"name",
ImmutableList.of(),
methodHandle(Procedures.class, "funWithVarargs", String[].class)))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Method must have fixed arity");
}

@Test
public void shouldThrowExceptionWhenArgumentCountDoesntMatch()
{
assertThatThrownBy(() -> new Procedure(
"schema",
"name",
ImmutableList.of(
new Procedure.Argument("name", VARCHAR, true, null),
new Procedure.Argument("name2", VARCHAR, true, null),
new Procedure.Argument("name3", VARCHAR, true, null)),
methodHandle(Procedures.class, "fun1", ConnectorSession.class, Object.class)))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Method parameter count must match arguments");
}

private static Procedure createTestProcedure(List<Procedure.Argument> arguments)
{
int argumentsCount = arguments.size();
String functionName = "fun" + argumentsCount;

Class<?>[] clazzes = new Class<?>[argumentsCount + 1];
clazzes[0] = ConnectorSession.class;

for (int i = 0; i < argumentsCount; i++) {
clazzes[i + 1] = Object.class;
}

return new Procedure(
"schema",
"name",
arguments,
methodHandle(Procedures.class, functionName, clazzes));
}

public static class Procedures
{
public void fun0(ConnectorSession session) {}

public void fun1(ConnectorSession session, Object arg1) {}

public void fun2(ConnectorSession session, Object arg1, Object arg2) {}

public void fun2(Object arg1, Object arg2) {}

public void fun3(ConnectorSession session, Object arg1, Object arg2, Object arg3) {}

public void fun4(ConnectorSession session, Object arg1, Object arg2, Object arg3, Object arg4) {}

public void fun5(ConnectorSession session, Object arg1, Object arg2, Object arg3, Object arg4, Object arg5) {}

public String funWithoutArguments()
{
return "";
}

public void funWithVarargs(String... arguments) {}
}
}