Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ public final class SystemSessionProperties
public static final String DYNAMIC_FILTERING_MAX_PER_DRIVER_ROW_COUNT = "dynamic_filtering_max_per_driver_row_count";
public static final String DYNAMIC_FILTERING_MAX_PER_DRIVER_SIZE = "dynamic_filtering_max_per_driver_size";
public static final String LEGACY_TYPE_COERCION_WARNING_ENABLED = "legacy_type_coercion_warning_enabled";
public static final String INLINE_SQL_FUNCTIONS = "inline_sql_functions";

private final List<PropertyMetadata<?>> sessionProperties;

Expand Down Expand Up @@ -869,7 +870,12 @@ public SystemSessionProperties(
LEGACY_TYPE_COERCION_WARNING_ENABLED,
"Enable warning for query relying on legacy type coercion",
featuresConfig.isLegacyDateTimestampToVarcharCoercion(),
true));
true),
booleanProperty(
INLINE_SQL_FUNCTIONS,
"Inline SQL function definition at plan time",
featuresConfig.isInlineSqlFunctions(),
false));
}

public List<PropertyMetadata<?>> getSessionProperties()
Expand Down Expand Up @@ -1470,4 +1476,9 @@ public static boolean isLegacyTypeCoercionWarningEnabled(Session session)
{
return session.getSystemProperty(LEGACY_TYPE_COERCION_WARNING_ENABLED, Boolean.class);
}

public static boolean isInlineSqlFunctions(Session session)
{
return session.getSystemProperty(INLINE_SQL_FUNCTIONS, Boolean.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CreateFunction;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.Return;
import com.facebook.presto.sql.tree.RoutineBody;
import com.facebook.presto.transaction.TransactionManager;
Expand Down Expand Up @@ -103,10 +105,32 @@ private SqlInvokedFunction createSqlInvokedFunction(CreateFunction statement, Me
Expression bodyExpression = ((Return) statement.getBody()).getExpression();
Type bodyType = analysis.getType(bodyExpression);

// Coerce expressions in body if necessary
bodyExpression = ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>()
{
@Override
public Expression rewriteExpression(Expression expression, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
Expression rewritten = treeRewriter.defaultRewrite(expression, null);

Type coercion = analysis.getCoercion(expression);
if (coercion != null) {
return new Cast(
rewritten,
coercion.getTypeSignature().toString(),
false,
analysis.isTypeOnlyCoercion(expression));
}
return rewritten;
}
}, bodyExpression, null);

if (!bodyType.equals(metadata.getType(returnType))) {
// Casting is safe-here, since we have verified that the actual type of the body is coercible to declared return type.
body = new Return(new Cast(bodyExpression, statement.getReturnType()));
// Casting is safe here, since we have verified at analysis time that the actual type of the body is coercible to declared return type.
bodyExpression = new Cast(bodyExpression, statement.getReturnType());
}

body = new Return(bodyExpression);
}

return new SqlInvokedFunction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ public class FeaturesConfig
private boolean preferDistributedUnion = true;
private boolean optimizeNullsInJoin;
private boolean pushdownDereferenceEnabled;
private boolean inlineSqlFunctions = true;

private String warnOnNoTableLayoutFilter = "";

Expand Down Expand Up @@ -1385,4 +1386,16 @@ public FeaturesConfig setWarnOnNoTableLayoutFilter(String warnOnNoTableLayoutFil
this.warnOnNoTableLayoutFilter = warnOnNoTableLayoutFilter;
return this;
}

public boolean isInlineSqlFunctions()
{
return inlineSqlFunctions;
}

@Config("inline-sql-functions")
public FeaturesConfig setInlineSqlFunctions(boolean inlineSqlFunctions)
{
this.inlineSqlFunctions = inlineSqlFunctions;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ protected Object visitFunctionCall(FunctionCall node, Object context)
result = functionInvoker.invoke(functionHandle, session.getSqlFunctionProperties(), argumentValues);
break;
case SQL:
Expression function = getSqlFunctionExpression(functionMetadata, (SqlInvokedScalarFunctionImplementation) metadata.getFunctionManager().getScalarFunctionImplementation(functionHandle), session.getSqlFunctionProperties(), node.getArguments());
Expression function = getSqlFunctionExpression(functionMetadata, (SqlInvokedScalarFunctionImplementation) metadata.getFunctionManager().getScalarFunctionImplementation(functionHandle), metadata, session.getSqlFunctionProperties(), node.getArguments());
ExpressionInterpreter functionInterpreter = new ExpressionInterpreter(
function,
metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import com.facebook.presto.sql.planner.iterative.rule.ImplementBernoulliSampleAsFilter;
import com.facebook.presto.sql.planner.iterative.rule.ImplementFilteredAggregations;
import com.facebook.presto.sql.planner.iterative.rule.InlineProjections;
import com.facebook.presto.sql.planner.iterative.rule.InlineSqlFunctions;
import com.facebook.presto.sql.planner.iterative.rule.MergeFilters;
import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithDistinct;
import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithSort;
Expand Down Expand Up @@ -330,6 +331,11 @@ public PlanOptimizers(
new LimitPushDown(), // Run the LimitPushDown after flattening set operators to make it easier to do the set flattening
new PruneUnreferencedOutputs(),
inlineProjections,
new IterativeOptimizer(
ruleStats,
statsCalculator,
estimatedExchangesCostCalculator,
new InlineSqlFunctions(metadata, sqlParser).rules()),
new IterativeOptimizer(
ruleStats,
statsCalculator,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.FunctionImplementationType;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.function.SqlInvokedScalarFunctionImplementation;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.NodeRef;
import com.google.common.collect.ImmutableSet;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.facebook.presto.SystemSessionProperties.isInlineSqlFunctions;
import static com.facebook.presto.metadata.FunctionManager.qualifyFunctionName;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.relational.SqlFunctionUtils.getSqlFunctionExpression;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;

public class InlineSqlFunctions
extends ExpressionRewriteRuleSet
{
public InlineSqlFunctions(Metadata metadata, SqlParser sqlParser)
{
super(createRewrite(metadata, sqlParser));
}

private static ExpressionRewriter createRewrite(Metadata metadata, SqlParser sqlParser)
{
requireNonNull(metadata, "metadata is null");
requireNonNull(sqlParser, "sqlParser is null");

return (expression, context) -> InlineSqlFunctionsRewriter.rewrite(
expression,
context.getSession(),
metadata,
getExpressionTypes(
context.getSession(),
metadata,
sqlParser,
context.getVariableAllocator().getTypes(),
expression,
emptyList(),
context.getWarningCollector()));
}

@Override
public Set<Rule<?>> rules()
{
// Aggregations are not rewritten because they cannot have SQL functions
return ImmutableSet.of(
projectExpressionRewrite(),
filterExpressionRewrite(),
joinExpressionRewrite(),
valuesExpressionRewrite());
}

public static class InlineSqlFunctionsRewriter
{
private InlineSqlFunctionsRewriter() {}

public static Expression rewrite(Expression expression, Session session, Metadata metadata, Map<NodeRef<Expression>, Type> expressionTypes)
{
if (isInlineSqlFunctions(session)) {
return ExpressionTreeRewriter.rewriteWith(new Visitor(session, metadata, expressionTypes), expression);
}
return expression;
}

private static class Visitor
extends com.facebook.presto.sql.tree.ExpressionRewriter<Void>
{
private final Session session;
private final Metadata metadata;
private final Map<NodeRef<Expression>, Type> expressionTypes;

public Visitor(Session session, Metadata metadata, Map<NodeRef<Expression>, Type> expressionTypes)
{
this.session = requireNonNull(session, "session is null");
this.metadata = requireNonNull(metadata, "metadata is null");
this.expressionTypes = expressionTypes;
}

@Override
public Expression rewriteFunctionCall(FunctionCall node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
List<Type> argumentTypes = new ArrayList<>();
List<Expression> rewrittenArguments = new ArrayList<>();
for (Expression argument : node.getArguments()) {
argumentTypes.add(expressionTypes.get(NodeRef.of(argument)));
rewrittenArguments.add(treeRewriter.rewrite(argument, context));
}

FunctionHandle functionHandle = metadata.getFunctionManager().resolveFunction(
session.getTransactionId(),
qualifyFunctionName(node.getName()),
fromTypes(argumentTypes));
FunctionMetadata functionMetadata = metadata.getFunctionManager().getFunctionMetadata(functionHandle);

if (functionMetadata.getImplementationType() != FunctionImplementationType.SQL) {
return new FunctionCall(node.getName(), rewrittenArguments);
}
return getSqlFunctionExpression(
functionMetadata,
(SqlInvokedScalarFunctionImplementation) metadata.getFunctionManager().getScalarFunctionImplementation(functionHandle),
metadata,
session.getSqlFunctionProperties(),
rewrittenArguments);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,18 @@
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

public final class SqlFunctionUtils
{
private SqlFunctionUtils() {}

public static Expression getSqlFunctionExpression(FunctionMetadata functionMetadata, SqlInvokedScalarFunctionImplementation implementation, SqlFunctionProperties sqlFunctionProperties, List<Expression> arguments)
public static Expression getSqlFunctionExpression(FunctionMetadata functionMetadata, SqlInvokedScalarFunctionImplementation implementation, Metadata metadata, SqlFunctionProperties sqlFunctionProperties, List<Expression> arguments)
{
checkArgument(functionMetadata.getImplementationType().equals(SQL), format("Expect SQL function, get %s", functionMetadata.getImplementationType()));
checkArgument(functionMetadata.getArgumentNames().isPresent(), "ArgumentNames is missing");
Expression expression = normalizeParameters(functionMetadata.getArgumentNames().get(), parseSqlFunctionExpression(implementation, sqlFunctionProperties));
expression = coerceIfNecessary(functionMetadata, expression, sqlFunctionProperties, metadata);
return SqlFunctionArgumentBinder.bindFunctionArguments(expression, functionMetadata.getArgumentNames().get(), arguments);
}

Expand Down Expand Up @@ -226,7 +226,7 @@ public static Expression bindFunctionArguments(Expression function, List<String>
for (int i = 0; i < argumentNames.size(); i++) {
argumentBindings.put(argumentNames.get(i), argumentValues.get(i));
}
return ExpressionTreeRewriter.rewriteWith(new ExpressionFunctionVisitor(argumentBindings.build()), function);
return ExpressionTreeRewriter.rewriteWith(new ExpressionFunctionVisitor(), function, argumentBindings.build());
}

public static RowExpression bindFunctionArguments(RowExpression function, List<Optional<String>> argumentNames, List<RowExpression> argumentValues)
Expand Down Expand Up @@ -262,20 +262,26 @@ public RowExpression rewriteVariableReference(VariableReferenceExpression variab
}

private static class ExpressionFunctionVisitor
extends ExpressionRewriter<Void>
extends ExpressionRewriter<Map<String, Expression>>
{
private final Map<String, Expression> argumentBindings;

public ExpressionFunctionVisitor(Map<String, Expression> argumentBindings)
@Override
public Expression rewriteLambdaExpression(LambdaExpression lambda, Map<String, Expression> context, ExpressionTreeRewriter<Map<String, Expression>> treeRewriter)
{
this.argumentBindings = requireNonNull(argumentBindings, "argumentBindings is null");
ImmutableList<String> lambdaStringArguments = lambda.getArguments().stream()
.map(x -> x.getName().getValue())
.collect(toImmutableList());
ImmutableMap<String, Expression> lambdaContext = context.entrySet().stream()
.filter(entry -> !lambdaStringArguments.contains(entry.getKey()))
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
Expression rewrittenBody = treeRewriter.rewrite(lambda.getBody(), lambdaContext);
return new LambdaExpression(lambda.getArguments(), rewrittenBody);
}

@Override
public Expression rewriteIdentifier(Identifier node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
public Expression rewriteIdentifier(Identifier node, Map<String, Expression> context, ExpressionTreeRewriter<Map<String, Expression>> treeRewriter)
{
if (argumentBindings.containsKey(node.getValue())) {
return argumentBindings.get(node.getValue());
if (context.containsKey(node.getValue())) {
return context.get(node.getValue());
}
return node;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ public void testDefaults()
.setOptimizeCommonSubExpressions(true)
.setPreferDistributedUnion(true)
.setOptimizeNullsInJoin(false)
.setWarnOnNoTableLayoutFilter(""));
.setWarnOnNoTableLayoutFilter("")
.setInlineSqlFunctions(true));
}

@Test
Expand Down Expand Up @@ -243,6 +244,7 @@ public void testExplicitPropertyMappings()
.put("prefer-distributed-union", "false")
.put("optimize-nulls-in-join", "true")
.put("warn-on-no-table-layout-filter", "ry@nlikestheyankees,ds")
.put("inline-sql-functions", "false")
.build();

FeaturesConfig expected = new FeaturesConfig()
Expand Down Expand Up @@ -338,7 +340,8 @@ public void testExplicitPropertyMappings()
.setOptimizeCommonSubExpressions(false)
.setPreferDistributedUnion(false)
.setOptimizeNullsInJoin(true)
.setWarnOnNoTableLayoutFilter("ry@nlikestheyankees,ds");
.setWarnOnNoTableLayoutFilter("ry@nlikestheyankees,ds")
.setInlineSqlFunctions(false);
assertFullMapping(properties, expected);
}

Expand Down
Loading