Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions core/trino-main/src/main/java/io/trino/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,17 @@ public SessionBuilder setSystemProperties(Map<String, String> systemProperties)
return this;
}

/**
* Sets catalog session properties, discarding any catalog properties previously set.
*/
public SessionBuilder setCatalogProperties(Map<String, Map<String, String>> catalogProperties)
{
requireNonNull(catalogProperties, "catalogProperties is null");
this.catalogSessionProperties.clear();
this.catalogSessionProperties.putAll(catalogProperties);
return this;
}

/**
* Sets a catalog property for the session. The property name and value must
* only contain characters from US-ASCII and must not be for '='.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
*/
package io.trino.sql.analyzer;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Streams;
import io.trino.Session;
Expand Down Expand Up @@ -150,6 +152,7 @@
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.spi.StandardErrorCode.AMBIGUOUS_NAME;
Expand Down Expand Up @@ -297,7 +300,7 @@ public class ExpressionAnalyzer
private final Function<Node, ResolvedWindow> getResolvedWindow;
private final List<Field> sourceFields = new ArrayList<>();

private ExpressionAnalyzer(
public ExpressionAnalyzer(
PlannerContext plannerContext,
AccessControl accessControl,
StatementAnalyzerFactory statementAnalyzerFactory,
Expand Down Expand Up @@ -2218,19 +2221,16 @@ protected Type visitInPredicate(InPredicate node, StackableAstVisitorContext<Con
});
}

Type declaredValueType = process(value, context);

if (valueList instanceof InListExpression) {
process(valueList, context);
InListExpression inListExpression = (InListExpression) valueList;

coerceToSingleType(context,
Type type = coerceToSingleType(context,
"IN value and list items must be the same type: %s",
ImmutableList.<Expression>builder().add(value).addAll(inListExpression.getValues()).build());
setExpressionType(inListExpression, type);
}
else if (valueList instanceof SubqueryExpression) {
subqueryInPredicates.add(NodeRef.of(node));
analyzePredicateWithSubquery(node, declaredValueType, (SubqueryExpression) valueList, context);
analyzePredicateWithSubquery(node, process(value, context), (SubqueryExpression) valueList, context);
}
else {
throw new IllegalArgumentException("Unexpected value list type for InPredicate: " + node.getValueList().getClass().getName());
Expand All @@ -2239,15 +2239,6 @@ else if (valueList instanceof SubqueryExpression) {
return setExpressionType(node, BOOLEAN);
}

@Override
protected Type visitInListExpression(InListExpression node, StackableAstVisitorContext<Context> context)
{
Type type = coerceToSingleType(context, "All IN list values must be the same type: %s", node.getValues());

setExpressionType(node, type);
return type; // TODO: this really should a be relation type
}

@Override
protected Type visitSubqueryExpression(SubqueryExpression node, StackableAstVisitorContext<Context> context)
{
Expand Down Expand Up @@ -2568,22 +2559,32 @@ private Type coerceToSingleType(StackableAstVisitorContext<Context> context, Str
{
// determine super type
Type superType = UNKNOWN;

ListMultimap<Type, Expression> typeExpressions = ArrayListMultimap.create();
for (Expression expression : expressions) {
Optional<Type> newSuperType = typeCoercion.getCommonSuperType(superType, process(expression, context));
typeExpressions.put(process(expression, context), expression);
}

// We need an explicit copy to avoid ConcurrentModificationException
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove obsolete comment

Set<Type> types = typeExpressions.keySet();
Comment on lines 2568 to 2569
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't see explicit copy here.


for (Type type : types) {
Optional<Type> newSuperType = typeCoercion.getCommonSuperType(superType, type);
if (newSuperType.isEmpty()) {
throw semanticException(TYPE_MISMATCH, expression, message, superType);
throw semanticException(TYPE_MISMATCH, typeExpressions.get(type).get(0), message, superType);
}
superType = newSuperType.get();
}

// verify all expressions can be coerced to the superType
Comment thread
sopel39 marked this conversation as resolved.
Outdated
for (Expression expression : expressions) {
Type type = process(expression, context);
for (Type type : types) {
List<Expression> coercionCandidates = typeExpressions.get(type);

if (!type.equals(superType)) {
if (!typeCoercion.canCoerce(type, superType)) {
throw semanticException(TYPE_MISMATCH, expression, message, superType);
throw semanticException(TYPE_MISMATCH, coercionCandidates.get(0), message, superType);
}
addOrReplaceExpressionCoercion(expression, type, superType);
addOrReplaceExpressionsCoercion(coercionCandidates, type, superType);
}
}

Expand All @@ -2592,13 +2593,20 @@ private Type coerceToSingleType(StackableAstVisitorContext<Context> context, Str

private void addOrReplaceExpressionCoercion(Expression expression, Type type, Type superType)
{
NodeRef<Expression> ref = NodeRef.of(expression);
expressionCoercions.put(ref, superType);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

List.of(expression) -> ImmutableList.of(expression)

addOrReplaceExpressionsCoercion(List.of(expression), type, superType);
Comment thread
wendigo marked this conversation as resolved.
Outdated
}

private void addOrReplaceExpressionsCoercion(List<Expression> expressions, Type type, Type superType)
{
Map<NodeRef<Expression>, Type> expressionRefTypes = expressions.stream()
.collect(toImmutableMap(NodeRef::of, expression -> superType));

expressionCoercions.putAll(expressionRefTypes);
if (typeCoercion.isTypeOnlyCoercion(type, superType)) {
typeOnlyCoercions.add(ref);
typeOnlyCoercions.addAll(expressionRefTypes.keySet());
}
else {
typeOnlyCoercions.remove(ref);
expressionRefTypes.keySet().forEach(typeOnlyCoercions::remove);
}
}
}
Expand Down Expand Up @@ -2773,8 +2781,13 @@ public static ExpressionAnalysis analyzeExpressions(
WarningCollector warningCollector,
QueryType queryType)
{
Analysis analysis = new Analysis(null, parameters, queryType);
ExpressionAnalyzer analyzer = new ExpressionAnalyzer(plannerContext, accessControl, statementAnalyzerFactory, analysis, session, types, warningCollector);
return analyzeExpressions(
new ExpressionAnalyzer(plannerContext, accessControl, statementAnalyzerFactory, new Analysis(null, parameters, queryType), session, types, warningCollector),
expressions);
}

public static ExpressionAnalysis analyzeExpressions(ExpressionAnalyzer analyzer, Iterable<Expression> expressions)
{
for (Expression expression : expressions) {
analyzer.analyze(
expression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,14 @@ protected Object visitInPredicate(InPredicate node, Object context)

ResolvedFunction equalsOperator = metadata.resolveOperator(session, OperatorType.EQUAL, types(node.getValue(), valueList));
for (Expression expression : valueList.getValues()) {
if (value instanceof Expression && expression instanceof Literal) {
// skip interpreting of literal IN term since it cannot be compared
// with unresolved "value" and it cannot be simplified further
values.add(expression);
types.add(type(expression));
continue;
}

// Use process() instead of processWithExceptionHandling() for processing in-list items.
// Do not handle exceptions thrown while processing a single in-list expression,
// but fail the whole in-predicate evaluation.
Expand Down Expand Up @@ -680,7 +688,16 @@ else if (!found && result) {
return new ComparisonExpression(ComparisonExpression.Operator.EQUAL, toExpression(value, type), simplifiedExpressionValues.get(0));
}

return new InPredicate(toExpression(value, type), new InListExpression(simplifiedExpressionValues));
Expression simplifiedValue = toExpression(value, type);
Expression simplifiedValueList = new InListExpression(simplifiedExpressionValues);
if (simplifiedValueList.equals(node.getValueList()) && simplifiedValue.equals(node.getValue())) {
// Do not create a new instance of InPredicate expression if it would be same as original expression.
// Creating a new instance of InPredicate would cause expression type cache miss, which
// is using node reference as a cache key.
return node;
}

return new InPredicate(simplifiedValue, simplifiedValueList);
}
if (hasNullValue) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,34 @@
*/
package io.trino.sql.planner;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.collect.cache.NonEvictableCache;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.AnalyzePropertyManager;
import io.trino.metadata.TablePropertyManager;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.QueryId;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.Analysis;
import io.trino.sql.analyzer.ExpressionAnalyzer;
import io.trino.sql.analyzer.StatementAnalyzerFactory;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NodeRef;

import javax.inject.Inject;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.sql.analyzer.ExpressionAnalyzer.analyzeExpressions;
import static io.trino.sql.analyzer.QueryType.OTHERS;
import static io.trino.sql.analyzer.StatementAnalyzerFactory.createTestingStatementAnalyzerFactory;
Expand All @@ -45,6 +56,14 @@ public class TypeAnalyzer
private final PlannerContext plannerContext;
private final StatementAnalyzerFactory statementAnalyzerFactory;

private final NonEvictableCache<QueryId, QueryScopedCachedTypeAnalyzer> typeAnalyzersCache = buildNonEvictableCache(
CacheBuilder.newBuilder()
// Try to evict queries cache as soon as possible to keep cache relatively small
.expireAfterAccess(15, TimeUnit.SECONDS)
.maximumSize(256)
.softValues()
.recordStats());

@Inject
public TypeAnalyzer(PlannerContext plannerContext, StatementAnalyzerFactory statementAnalyzerFactory)
{
Expand All @@ -54,17 +73,13 @@ public TypeAnalyzer(PlannerContext plannerContext, StatementAnalyzerFactory stat

public Map<NodeRef<Expression>, Type> getTypes(Session session, TypeProvider inputTypes, Iterable<Expression> expressions)
{
return analyzeExpressions(
session,
plannerContext,
statementAnalyzerFactory,
new AllowAllAccessControl(),
inputTypes,
expressions,
ImmutableMap.of(),
WarningCollector.NOOP,
OTHERS)
.getExpressionTypes();
try {
return typeAnalyzersCache.get(session.getQueryId(), () -> new QueryScopedCachedTypeAnalyzer(plannerContext, statementAnalyzerFactory))
.getTypes(session, inputTypes, ImmutableList.copyOf(expressions));
}
catch (ExecutionException e) {
throw new RuntimeException(e);
}
}

public Map<NodeRef<Expression>, Type> getTypes(Session session, TypeProvider inputTypes, Expression expression)
Expand All @@ -87,4 +102,59 @@ public static TypeAnalyzer createTestingTypeAnalyzer(PlannerContext plannerConte
new TablePropertyManager(),
new AnalyzePropertyManager()));
}

private static class QueryScopedCachedTypeAnalyzer
{
private final Cache<NodeRef<Expression>, Type> typesCache = buildNonEvictableCache(CacheBuilder.newBuilder());
private PlannerContext plannerContext;
private StatementAnalyzerFactory statementAnalyzerFactory;

private QueryScopedCachedTypeAnalyzer(PlannerContext plannerContext, StatementAnalyzerFactory statementAnalyzerFactory)
{
this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
this.statementAnalyzerFactory = requireNonNull(statementAnalyzerFactory, "statementAnalyzerFactory is null");
}

private Map<NodeRef<Expression>, Type> getTypes(Session session, TypeProvider inputTypes, List<Expression> expressions)
{
List<NodeRef<Expression>> expressionsToResolve = collectExpressions(expressions);
Map<NodeRef<Expression>, Type> cachedTypes = typesCache.getAllPresent(expressionsToResolve);

// All expressions were resolved from cache
if (cachedTypes.size() == expressionsToResolve.size()) {
return cachedTypes;
}

Map<NodeRef<Expression>, Type> resolvedTypes = analyzeExpressions(createExpressionAnalyzer(session, plannerContext, statementAnalyzerFactory, inputTypes), expressions)
.getExpressionTypes();

typesCache.putAll(resolvedTypes);
return resolvedTypes;
}

private static ExpressionAnalyzer createExpressionAnalyzer(Session session,
PlannerContext plannerContext,
StatementAnalyzerFactory statementAnalyzerFactory,
TypeProvider types)
{
return new ExpressionAnalyzer(plannerContext, new AllowAllAccessControl(), statementAnalyzerFactory, new Analysis(null, ImmutableMap.of(), OTHERS), session, types, WarningCollector.NOOP);
}

private static ImmutableList<NodeRef<Expression>> collectExpressions(Iterable<? extends Node> expressions)
{
ImmutableList.Builder<NodeRef<Expression>> builder = ImmutableList.builder();

for (Node expression : expressions) {
if (expression instanceof Expression) {
builder.add(NodeRef.of((Expression) expression));
}

if (!expression.getChildren().isEmpty()) {
builder.addAll(collectExpressions(expression.getChildren()));
}
}

return builder.build();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
*/
package io.trino.sql.planner.plan;

import java.util.List;
import com.google.common.collect.ImmutableList;

import static com.google.common.base.Verify.verifyNotNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.sql.planner.plan.ChildReplacer.replaceChildren;

public abstract class SimplePlanRewriter<C>
Expand Down Expand Up @@ -69,11 +68,9 @@ public PlanNode defaultRewrite(PlanNode node)
*/
public PlanNode defaultRewrite(PlanNode node, C context)
{
List<PlanNode> children = node.getSources().stream()
.map(child -> rewrite(child, context))
.collect(toImmutableList());

return replaceChildren(node, children);
ImmutableList.Builder<PlanNode> children = ImmutableList.builderWithExpectedSize(node.getSources().size());
Comment thread
findepi marked this conversation as resolved.
Outdated
node.getSources().forEach(source -> children.add(rewrite(source, context)));
return replaceChildren(node, children.build());
}

/**
Expand Down
Loading