Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
Expand All @@ -52,7 +53,7 @@ public static Expression translate(ConnectorExpression expression, Map<String, S
return new ConnectorToSqlExpressionTranslator(variableMappings, literalEncoder).translate(expression);
}

public static ConnectorExpression translate(Session session, Expression expression, TypeAnalyzer types, TypeProvider inputTypes)
public static Optional<ConnectorExpression> translate(Session session, Expression expression, TypeAnalyzer types, TypeProvider inputTypes)
{
return new SqlToConnectorExpressionTranslator(types.getTypes(session, inputTypes, expression))
.process(expression);
Expand Down Expand Up @@ -82,7 +83,7 @@ public Expression translate(ConnectorExpression expression)
if (expression instanceof FieldDereference) {
FieldDereference dereference = (FieldDereference) expression;

RowType type = (RowType) expression.getType();
RowType type = (RowType) dereference.getTarget().getType();
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to test this?

also, cmt title is too long. Maybe just Fix translation of FieldDereference to DereferenceExpression.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@findepi Fixed the message. Added a test in TestConnectorExpressionTranslator. I haven't added literals for the test since the translation for them would rely on how ExpressionAnalyzer and LiteralEncoder.

String name = type.getFields().get(dereference.getField()).getName().get();
return new DereferenceExpression(translate(dereference.getTarget()), new Identifier(name));
}
Expand All @@ -91,73 +92,78 @@ public Expression translate(ConnectorExpression expression)
}
}

private static class SqlToConnectorExpressionTranslator
extends AstVisitor<ConnectorExpression, Void>
static class SqlToConnectorExpressionTranslator
extends AstVisitor<Optional<ConnectorExpression>, Void>
{
private final Map<NodeRef<Expression>, Type> types;

private SqlToConnectorExpressionTranslator(Map<NodeRef<Expression>, Type> types)
public SqlToConnectorExpressionTranslator(Map<NodeRef<Expression>, Type> types)
{
this.types = requireNonNull(types, "types is null");
}

@Override
protected ConnectorExpression visitSymbolReference(SymbolReference node, Void context)
protected Optional<ConnectorExpression> visitSymbolReference(SymbolReference node, Void context)
{
return new Variable(node.getName(), typeOf(node));
return Optional.of(new Variable(node.getName(), typeOf(node)));
}

@Override
protected ConnectorExpression visitBooleanLiteral(BooleanLiteral node, Void context)
protected Optional<ConnectorExpression> visitBooleanLiteral(BooleanLiteral node, Void context)
{
return new Constant(node.getValue(), typeOf(node));
return Optional.of(new Constant(node.getValue(), typeOf(node)));
}

@Override
protected ConnectorExpression visitStringLiteral(StringLiteral node, Void context)
protected Optional<ConnectorExpression> visitStringLiteral(StringLiteral node, Void context)
{
return new Constant(node.getSlice(), typeOf(node));
return Optional.of(new Constant(node.getSlice(), typeOf(node)));
}

@Override
protected ConnectorExpression visitDoubleLiteral(DoubleLiteral node, Void context)
protected Optional<ConnectorExpression> visitDoubleLiteral(DoubleLiteral node, Void context)
{
return new Constant(node.getValue(), typeOf(node));
return Optional.of(new Constant(node.getValue(), typeOf(node)));
}

@Override
protected ConnectorExpression visitDecimalLiteral(DecimalLiteral node, Void context)
protected Optional<ConnectorExpression> visitDecimalLiteral(DecimalLiteral node, Void context)
{
return new Constant(Decimals.parse(node.getValue()).getObject(), typeOf(node));
return Optional.of(new Constant(Decimals.parse(node.getValue()).getObject(), typeOf(node)));
}

@Override
protected ConnectorExpression visitCharLiteral(CharLiteral node, Void context)
protected Optional<ConnectorExpression> visitCharLiteral(CharLiteral node, Void context)
{
return new Constant(node.getSlice(), typeOf(node));
return Optional.of(new Constant(node.getSlice(), typeOf(node)));
}

@Override
protected ConnectorExpression visitBinaryLiteral(BinaryLiteral node, Void context)
protected Optional<ConnectorExpression> visitBinaryLiteral(BinaryLiteral node, Void context)
{
return new Constant(node.getValue(), typeOf(node));
return Optional.of(new Constant(node.getValue(), typeOf(node)));
}

@Override
protected ConnectorExpression visitLongLiteral(LongLiteral node, Void context)
protected Optional<ConnectorExpression> visitLongLiteral(LongLiteral node, Void context)
{
return new Constant(node.getValue(), typeOf(node));
return Optional.of(new Constant(node.getValue(), typeOf(node)));
}

@Override
protected ConnectorExpression visitNullLiteral(NullLiteral node, Void context)
protected Optional<ConnectorExpression> visitNullLiteral(NullLiteral node, Void context)
{
return new Constant(null, typeOf(node));
return Optional.of(new Constant(null, typeOf(node)));
}

@Override
protected ConnectorExpression visitDereferenceExpression(DereferenceExpression node, Void context)
protected Optional<ConnectorExpression> visitDereferenceExpression(DereferenceExpression node, Void context)
{
Optional<ConnectorExpression> translatedBase = process(node.getBase());
if (!translatedBase.isPresent()) {
return Optional.empty();
}

RowType rowType = (RowType) typeOf(node.getBase());
String fieldName = node.getField().getValue();
List<RowType.Field> fields = rowType.getFields();
Expand All @@ -172,13 +178,13 @@ protected ConnectorExpression visitDereferenceExpression(DereferenceExpression n

checkState(index >= 0, "could not find field name: %s", node.getField());

return new FieldDereference(typeOf(node), process(node.getBase()), index);
return Optional.of(new FieldDereference(typeOf(node), translatedBase.get(), index));
}

@Override
protected ConnectorExpression visitExpression(Expression node, Void context)
protected Optional<ConnectorExpression> visitExpression(Expression node, Void context)
{
throw new UnsupportedOperationException("not yet implemented: expression translator for " + node.getClass().getName());
return Optional.empty();
}

private Type typeOf(Expression node)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.spi.expression;

import com.google.common.collect.ImmutableMap;
import io.prestosql.Session;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.tree.AstVisitor;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.LambdaExpression;
import io.prestosql.sql.tree.NodeRef;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

import static java.util.Objects.requireNonNull;

public class PartialTranslator
{
private PartialTranslator() {}

/**
* Produces {@link ConnectorExpression} translations for disjoint components in the {@param inputExpression} in a
* top-down manner. i.e. if an expression node is translatable, we do not consider its children.
*/
public static Map<NodeRef<Expression>, ConnectorExpression> extractPartialTranslations(
Expression inputExpression,
Session session,
TypeAnalyzer typeAnalyzer,
TypeProvider typeProvider)
{
requireNonNull(inputExpression, "expressions is null");
requireNonNull(session, "session is null");
requireNonNull(typeAnalyzer, "typeAnalyzer is null");
requireNonNull(typeProvider, "typeProvider is null");

Map<NodeRef<Expression>, ConnectorExpression> partialTranslations = new HashMap<>();
new Visitor(typeAnalyzer.getTypes(session, typeProvider, inputExpression), partialTranslations).process(inputExpression);
return ImmutableMap.copyOf(partialTranslations);
}

private static class Visitor
extends AstVisitor<Void, Void>
{
private final Map<NodeRef<Expression>, ConnectorExpression> translatedSubExpressions;
private final ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator translator;

Visitor(Map<NodeRef<Expression>, Type> types, Map<NodeRef<Expression>, ConnectorExpression> translatedSubExpressions)
{
requireNonNull(types, "types is null");
this.translatedSubExpressions = requireNonNull(translatedSubExpressions, "translatedSubExpressions is null");
this.translator = new ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator(types);
}

@Override
public Void visitExpression(Expression node, Void context)
{
Optional<ConnectorExpression> result = translator.process(node);

if (result.isPresent()) {
translatedSubExpressions.put(NodeRef.of(node), result.get());
}
else {
node.getChildren().forEach(this::process);
}

return null;
}

// TODO support lambda expressions for partial projection
@Override
public Void visitLambdaExpression(LambdaExpression functionCall, Void context)
{
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public static Expression replaceExpression(Expression expression, Map<NodeRef<Ex

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rephrase the commit message: "Update constructor visibility in ReferenceAwareExpressionNodeInliner"

private final Map<NodeRef<Expression>, Expression> mappings;

public ReferenceAwareExpressionNodeInliner(Map<NodeRef<Expression>, Expression> mappings)
private ReferenceAwareExpressionNodeInliner(Map<NodeRef<Expression>, Expression> mappings)
{
this.mappings = ImmutableMap.copyOf(requireNonNull(mappings, "mappings is null"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
*/
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
Expand All @@ -30,16 +32,20 @@
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.NodeRef;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.prestosql.matching.Capture.newCapture;
import static io.prestosql.spi.expression.PartialTranslator.extractPartialTranslations;
import static io.prestosql.sql.planner.ReferenceAwareExpressionNodeInliner.replaceExpression;
import static io.prestosql.sql.planner.plan.Patterns.project;
import static io.prestosql.sql.planner.plan.Patterns.source;
import static io.prestosql.sql.planner.plan.Patterns.tableScan;
Expand Down Expand Up @@ -71,37 +77,45 @@ public Result apply(ProjectNode project, Captures captures, Context context)
{
TableScanNode tableScan = captures.get(TABLE_SCAN);

List<ConnectorExpression> projections;
try {
projections = project.getAssignments()
.getExpressions().stream()
.map(expression -> ConnectorExpressionTranslator.translate(
context.getSession(),
expression,
typeAnalyzer,
context.getSymbolAllocator().getTypes()))
.collect(toImmutableList());
}
catch (UnsupportedOperationException e) {
// some expression not supported by translator, skip
// TODO: Support pushing down the expressions that could be translated
// TODO: A possible approach might be:
// 1. For expressions that could not be translated, extract column references
// 2. Provide those column references as part of the call to applyProjection
// 3. Re-assemble a projection based on the new projections + un-translateble projections
// rewritten in terms of the new assignments for the columns passed in #2
return Result.empty();
Map<Symbol, Expression> inputExpressions = project.getAssignments().getMap();

ImmutableList.Builder<NodeRef<Expression>> nodeReferencesBuilder = ImmutableList.builder();
ImmutableList.Builder<ConnectorExpression> partialProjectionsBuilder = ImmutableList.builder();

// Extract translatable components from projection expressions. Prepare a mapping from these internal
// expression nodes to corresponding ConnectorExpression translations.
for (Map.Entry<Symbol, Expression> expression : inputExpressions.entrySet()) {
Map<NodeRef<Expression>, ConnectorExpression> partialTranslations = extractPartialTranslations(
expression.getValue(),
context.getSession(),
typeAnalyzer,
context.getSymbolAllocator().getTypes());

partialTranslations.forEach((nodeRef, expr) -> {
nodeReferencesBuilder.add(nodeRef);
partialProjectionsBuilder.add(expr);
});
}

List<NodeRef<Expression>> nodesForPartialProjections = nodeReferencesBuilder.build();
List<ConnectorExpression> connectorPartialProjections = partialProjectionsBuilder.build();

Map<String, ColumnHandle> assignments = tableScan.getAssignments()
.entrySet().stream()
.collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue));

Optional<ProjectionApplicationResult<TableHandle>> result = metadata.applyProjection(context.getSession(), tableScan.getTable(), projections, assignments);
Optional<ProjectionApplicationResult<TableHandle>> result = metadata.applyProjection(context.getSession(), tableScan.getTable(), connectorPartialProjections, assignments);

if (!result.isPresent()) {
return Result.empty();
}

List<ConnectorExpression> newConnectorPartialProjections = result.get().getProjections();
checkState(newConnectorPartialProjections.size() == connectorPartialProjections.size(),
"Mismatch between input and output projections from the connector: expected %s but got %s",
connectorPartialProjections.size(),
newConnectorPartialProjections.size());

List<Symbol> newScanOutputs = new ArrayList<>();
Map<Symbol, ColumnHandle> newScanAssignments = new HashMap<>();
Map<String, Symbol> variableMappings = new HashMap<>();
Expand All @@ -113,16 +127,23 @@ public Result apply(ProjectNode project, Captures captures, Context context)
variableMappings.put(assignment.getVariable(), symbol);
}

// TODO: ensure newProjections.size == original projections.size

List<Expression> newProjections = result.get().getProjections().stream()
// Translate partial connector projections back to new partial projections
List<Expression> newPartialProjections = newConnectorPartialProjections.stream()
.map(expression -> ConnectorExpressionTranslator.translate(expression, variableMappings, new LiteralEncoder(metadata)))
.collect(toImmutableList());

Assignments.Builder newProjectionAssignments = Assignments.builder();
for (int i = 0; i < project.getOutputSymbols().size(); i++) {
newProjectionAssignments.put(project.getOutputSymbols().get(i), newProjections.get(i));
// Map internal node references to new partial projections
ImmutableMap.Builder<NodeRef<Expression>, Expression> nodesToNewPartialProjectionsBuilder = ImmutableMap.builder();
for (int i = 0; i < nodesForPartialProjections.size(); i++) {
nodesToNewPartialProjectionsBuilder.put(nodesForPartialProjections.get(i), newPartialProjections.get(i));
}
Map<NodeRef<Expression>, Expression> nodesToNewPartialProjections = nodesToNewPartialProjectionsBuilder.build();

// Stitch partial translations to form new complete projections
Assignments.Builder newProjectionAssignments = Assignments.builder();
project.getAssignments().entrySet().forEach(entry -> {
newProjectionAssignments.put(entry.getKey(), replaceExpression(entry.getValue(), nodesToNewPartialProjections));
});

return Result.ofPlanNode(
new ProjectNode(
Expand Down
Loading