diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java index 4e75f6613429f..a5e59883c8581 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java @@ -144,10 +144,10 @@ public String getGraphvizPlan(Session session, Statement statement, Type planTyp switch (planType) { case LOGICAL: Plan plan = getLogicalPlan(session, statement, parameters, warningCollector); - return graphvizLogicalPlan(plan.getRoot(), plan.getTypes(), session); + return graphvizLogicalPlan(plan.getRoot(), plan.getTypes(), session, metadata.getFunctionManager()); case DISTRIBUTED: SubPlan subPlan = getDistributedPlan(session, statement, parameters, warningCollector); - return graphvizDistributedPlan(subPlan, session); + return graphvizDistributedPlan(subPlan, session, metadata.getFunctionManager()); } throw new IllegalArgumentException("Unhandled plan type: " + planType); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index e11b3c1fec299..994342b1fdcc6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -163,7 +163,7 @@ private PlanPrinter( .sum(), MILLISECONDS)); this.representation = new PlanRepresentation(planRoot, types, totalCpuTime, totalScheduledTime); - this.formatter = new RowExpressionFormatter(session.toConnectorSession()); + this.formatter = new RowExpressionFormatter(session.toConnectorSession(), functionManager); Visitor visitor = new Visitor(stageExecutionStrategy, types, estimatedStatsAndCosts, session, stats); planRoot.accept(visitor, null); @@ -312,7 +312,7 @@ private static String formatFragment(FunctionManager functionManager, Session se return builder.toString(); } - public static String graphvizLogicalPlan(PlanNode plan, TypeProvider types, Session session) + public static String graphvizLogicalPlan(PlanNode plan, TypeProvider types, Session session, FunctionManager functionManager) { // TODO: This should move to something like GraphvizRenderer PlanFragment fragment = new PlanFragment( @@ -325,12 +325,12 @@ public static String graphvizLogicalPlan(PlanNode plan, TypeProvider types, Sess StageExecutionDescriptor.ungroupedExecution(), StatsAndCosts.empty(), Optional.empty()); - return GraphvizPrinter.printLogical(ImmutableList.of(fragment), session); + return GraphvizPrinter.printLogical(ImmutableList.of(fragment), session, functionManager); } - public static String graphvizDistributedPlan(SubPlan plan, Session session) + public static String graphvizDistributedPlan(SubPlan plan, Session session, FunctionManager functionManager) { - return GraphvizPrinter.printDistributed(plan, session); + return GraphvizPrinter.printDistributed(plan, session, functionManager); } private class Visitor diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/RowExpressionFormatter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/RowExpressionFormatter.java index b98eb3d9abd96..bae8e21231f26 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/RowExpressionFormatter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/RowExpressionFormatter.java @@ -13,8 +13,11 @@ */ package com.facebook.presto.sql.planner.planPrinter; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.InputReferenceExpression; @@ -25,9 +28,11 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.LiteralInterpreter; +import com.facebook.presto.sql.relational.FunctionResolution; import java.util.List; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @@ -35,10 +40,14 @@ public final class RowExpressionFormatter { private final ConnectorSession session; + private final FunctionMetadataManager functionMetadataManager; + private final StandardFunctionResolution standardFunctionResolution; - public RowExpressionFormatter(ConnectorSession session) + public RowExpressionFormatter(ConnectorSession session, FunctionManager functionManager) { this.session = requireNonNull(session, "session is null"); + this.functionMetadataManager = requireNonNull(functionManager, "function manager is null"); + this.standardFunctionResolution = new FunctionResolution(functionManager); } public String formatRowExpression(RowExpression expression) @@ -57,12 +66,32 @@ public class Formatter @Override public String visitCall(CallExpression node, Void context) { + if (standardFunctionResolution.isArithmeticFunction(node.getFunctionHandle()) || standardFunctionResolution.isComparisonFunction(node.getFunctionHandle())) { + String operation = functionMetadataManager.getFunctionMetadata(node.getFunctionHandle()).getOperatorType().get().getOperator(); + return String.join(" " + operation + " ", formatRowExpressions(node.getArguments()).stream().map(e -> "(" + e + ")").collect(toImmutableList())); + } + else if (standardFunctionResolution.isCastFunction(node.getFunctionHandle())) { + return String.format("CAST(%s AS %s)", formatRowExpression(node.getArguments().get(0)), node.getType().getDisplayName()); + } + else if (standardFunctionResolution.isNegateFunction(node.getFunctionHandle())) { + return "-(" + formatRowExpression(node.getArguments().get(0)) + ")"; + } + else if (standardFunctionResolution.isSubscriptFunction(node.getFunctionHandle())) { + return formatRowExpression(node.getArguments().get(0)) + "[" + formatRowExpression(node.getArguments().get(1)) + "]"; + } + else if (standardFunctionResolution.isBetweenFunction(node.getFunctionHandle())) { + List formattedExpresions = formatRowExpressions(node.getArguments()); + return String.format("%s BETWEEN (%s) AND (%s)", formattedExpresions.get(0), formattedExpresions.get(1), formattedExpresions.get(2)); + } return node.getDisplayName() + "(" + String.join(", ", formatRowExpressions(node.getArguments())) + ")"; } @Override public String visitSpecialForm(SpecialFormExpression node, Void context) { + if (node.getForm().equals(SpecialFormExpression.Form.AND) || node.getForm().equals(SpecialFormExpression.Form.OR)) { + return String.join(" " + node.getForm() + " ", formatRowExpressions(node.getArguments()).stream().map(e -> "(" + e + ")").collect(toImmutableList())); + } return node.getForm().name() + "(" + String.join(", ", formatRowExpressions(node.getArguments())) + ")"; } diff --git a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java index fe522c29d6332..470a4638d7e6c 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java @@ -14,6 +14,7 @@ package com.facebook.presto.util; import com.facebook.presto.Session; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.sql.planner.SubPlan; @@ -133,7 +134,7 @@ private enum NodeType private GraphvizPrinter() {} - public static String printLogical(List fragments, Session session) + public static String printLogical(List fragments, Session session, FunctionManager functionManager) { Map fragmentsById = Maps.uniqueIndex(fragments, PlanFragment::getId); PlanNodeIdGenerator idGenerator = new PlanNodeIdGenerator(); @@ -142,7 +143,7 @@ public static String printLogical(List fragments, Session session) output.append("digraph logical_plan {\n"); for (PlanFragment fragment : fragments) { - printFragmentNodes(output, fragment, idGenerator, session); + printFragmentNodes(output, fragment, idGenerator, session, functionManager); } for (PlanFragment fragment : fragments) { @@ -154,7 +155,7 @@ public static String printLogical(List fragments, Session session) return output.toString(); } - public static String printDistributed(SubPlan plan, Session session) + public static String printDistributed(SubPlan plan, Session session, FunctionManager functionManager) { List fragments = plan.getAllFragments(); Map fragmentsById = Maps.uniqueIndex(fragments, PlanFragment::getId); @@ -163,25 +164,31 @@ public static String printDistributed(SubPlan plan, Session session) StringBuilder output = new StringBuilder(); output.append("digraph distributed_plan {\n"); - printSubPlan(plan, fragmentsById, idGenerator, output, session); + printSubPlan(plan, fragmentsById, idGenerator, output, session, functionManager); output.append("}\n"); return output.toString(); } - private static void printSubPlan(SubPlan plan, Map fragmentsById, PlanNodeIdGenerator idGenerator, StringBuilder output, Session session) + private static void printSubPlan( + SubPlan plan, + Map fragmentsById, + PlanNodeIdGenerator idGenerator, + StringBuilder output, + Session session, + FunctionManager functionManager) { PlanFragment fragment = plan.getFragment(); - printFragmentNodes(output, fragment, idGenerator, session); + printFragmentNodes(output, fragment, idGenerator, session, functionManager); fragment.getRoot().accept(new EdgePrinter(output, fragmentsById, idGenerator), null); for (SubPlan child : plan.getChildren()) { - printSubPlan(child, fragmentsById, idGenerator, output, session); + printSubPlan(child, fragmentsById, idGenerator, output, session, functionManager); } } - private static void printFragmentNodes(StringBuilder output, PlanFragment fragment, PlanNodeIdGenerator idGenerator, Session session) + private static void printFragmentNodes(StringBuilder output, PlanFragment fragment, PlanNodeIdGenerator idGenerator, Session session, FunctionManager functionManager) { String clusterId = "cluster_" + fragment.getId(); output.append("subgraph ") @@ -193,7 +200,7 @@ private static void printFragmentNodes(StringBuilder output, PlanFragment fragme .append('\n'); PlanNode plan = fragment.getRoot(); - plan.accept(new NodePrinter(output, idGenerator, session), null); + plan.accept(new NodePrinter(output, idGenerator, session, functionManager), null); output.append("}") .append('\n'); @@ -207,11 +214,11 @@ private static class NodePrinter private final PlanNodeIdGenerator idGenerator; private final RowExpressionFormatter formatter; - public NodePrinter(StringBuilder output, PlanNodeIdGenerator idGenerator, Session session) + public NodePrinter(StringBuilder output, PlanNodeIdGenerator idGenerator, Session session, FunctionManager functionManager) { this.output = output; this.idGenerator = idGenerator; - this.formatter = new RowExpressionFormatter(session.toConnectorSession()); + this.formatter = new RowExpressionFormatter(session.toConnectorSession(), functionManager); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestRowExpressionFormatter.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestRowExpressionFormatter.java index 64361524a89e3..1923f10d76418 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestRowExpressionFormatter.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestRowExpressionFormatter.java @@ -13,18 +13,50 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.block.BlockEncodingManager; +import com.facebook.presto.metadata.CastType; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.spi.block.LongArrayBlockBuilder; +import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.Decimals; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.planner.planPrinter.RowExpressionFormatter; +import com.facebook.presto.type.TypeRegistry; +import com.google.common.collect.ImmutableList; import com.google.common.io.BaseEncoding; import io.airlift.slice.Slice; import io.airlift.slice.Slices; import org.testng.annotations.Test; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.spi.function.OperatorType.ADD; +import static com.facebook.presto.spi.function.OperatorType.BETWEEN; +import static com.facebook.presto.spi.function.OperatorType.CAST; +import static com.facebook.presto.spi.function.OperatorType.DIVIDE; +import static com.facebook.presto.spi.function.OperatorType.EQUAL; +import static com.facebook.presto.spi.function.OperatorType.GREATER_THAN; +import static com.facebook.presto.spi.function.OperatorType.GREATER_THAN_OR_EQUAL; +import static com.facebook.presto.spi.function.OperatorType.HASH_CODE; +import static com.facebook.presto.spi.function.OperatorType.IS_DISTINCT_FROM; +import static com.facebook.presto.spi.function.OperatorType.LESS_THAN; +import static com.facebook.presto.spi.function.OperatorType.LESS_THAN_OR_EQUAL; +import static com.facebook.presto.spi.function.OperatorType.MODULUS; +import static com.facebook.presto.spi.function.OperatorType.MULTIPLY; +import static com.facebook.presto.spi.function.OperatorType.NEGATION; +import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL; +import static com.facebook.presto.spi.function.OperatorType.SUBSCRIPT; +import static com.facebook.presto.spi.function.OperatorType.SUBTRACT; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.AND; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.OR; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.CharType.createCharType; @@ -37,6 +69,8 @@ import static com.facebook.presto.spi.type.TinyintType.TINYINT; import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.relational.Expressions.call; import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.relational.Expressions.constantNull; import static com.facebook.presto.type.ColorType.COLOR; @@ -46,7 +80,12 @@ public class TestRowExpressionFormatter { - private static final RowExpressionFormatter FORMATTER = new RowExpressionFormatter(TEST_SESSION.toConnectorSession()); + private static final TypeManager typeManager = new TypeRegistry(); + private static final FunctionManager functionManager = new FunctionManager(typeManager, new BlockEncodingManager(typeManager), new FeaturesConfig()); + private static final RowExpressionFormatter FORMATTER = new RowExpressionFormatter(TEST_SESSION.toConnectorSession(), functionManager); + private static final VariableReferenceExpression C_BIGINT = new VariableReferenceExpression("c_bigint", BIGINT); + private static final VariableReferenceExpression C_BIGINT_ARRAY = new VariableReferenceExpression("c_bigint_array", new ArrayType(BIGINT)); + @Test public void testConstants() { @@ -116,11 +155,161 @@ public void testConstants() assertEquals(format(constantExpression), "[Block: position count: 2; size: 96 bytes]"); } + @Test + public void testCalls() + { + RowExpression callExpression; + + // arithmetic + callExpression = createCallExpression(ADD); + assertEquals(format(callExpression), "(c_bigint) + (BIGINT 5)"); + callExpression = createCallExpression(SUBTRACT); + assertEquals(format(callExpression), "(c_bigint) - (BIGINT 5)"); + callExpression = createCallExpression(MULTIPLY); + assertEquals(format(callExpression), "(c_bigint) * (BIGINT 5)"); + callExpression = createCallExpression(DIVIDE); + assertEquals(format(callExpression), "(c_bigint) / (BIGINT 5)"); + callExpression = createCallExpression(MODULUS); + assertEquals(format(callExpression), "(c_bigint) % (BIGINT 5)"); + + // comparison + callExpression = createCallExpression(GREATER_THAN); + assertEquals(format(callExpression), "(c_bigint) > (BIGINT 5)"); + callExpression = createCallExpression(LESS_THAN); + assertEquals(format(callExpression), "(c_bigint) < (BIGINT 5)"); + callExpression = createCallExpression(GREATER_THAN_OR_EQUAL); + assertEquals(format(callExpression), "(c_bigint) >= (BIGINT 5)"); + callExpression = createCallExpression(LESS_THAN_OR_EQUAL); + assertEquals(format(callExpression), "(c_bigint) <= (BIGINT 5)"); + callExpression = createCallExpression(EQUAL); + assertEquals(format(callExpression), "(c_bigint) = (BIGINT 5)"); + callExpression = createCallExpression(NOT_EQUAL); + assertEquals(format(callExpression), "(c_bigint) <> (BIGINT 5)"); + callExpression = createCallExpression(IS_DISTINCT_FROM); + assertEquals(format(callExpression), "(c_bigint) IS DISTINCT FROM (BIGINT 5)"); + + // negation + RowExpression expression = createCallExpression(ADD); + callExpression = call( + NEGATION.name(), + functionManager.resolveOperator(NEGATION, fromTypes(expression.getType())), + expression.getType(), + expression); + assertEquals(format(callExpression), "-((c_bigint) + (BIGINT 5))"); + + // subscript + ArrayType arrayType = (ArrayType) C_BIGINT_ARRAY.getType(); + Type elementType = arrayType.getElementType(); + RowExpression subscriptExpression = call(SUBSCRIPT.name(), + functionManager.resolveOperator(SUBSCRIPT, fromTypes(arrayType, elementType)), + elementType, + ImmutableList.of(C_BIGINT_ARRAY, constant(0, INTEGER))); + callExpression = subscriptExpression; + assertEquals(format(callExpression), "c_bigint_array[INTEGER 0]"); + + // cast + callExpression = call( + CAST.name(), + functionManager.lookupCast(CastType.CAST, TINYINT.getTypeSignature(), BIGINT.getTypeSignature()), + BIGINT, + constant(1, TINYINT)); + assertEquals(format(callExpression), "CAST(TINYINT 1 AS bigint)"); + + // between + callExpression = call( + BETWEEN.name(), + functionManager.resolveOperator(BETWEEN, fromTypes(BIGINT, BIGINT, BIGINT)), + BOOLEAN, + subscriptExpression, + constant(1, BIGINT), + constant(5, BIGINT)); + assertEquals(format(callExpression), "c_bigint_array[INTEGER 0] BETWEEN (BIGINT 1) AND (BIGINT 5)"); + + // other + callExpression = call( + HASH_CODE.name(), + functionManager.resolveOperator(HASH_CODE, fromTypes(BIGINT)), + BIGINT, + constant(1, BIGINT)); + assertEquals(format(callExpression), "HASH_CODE(BIGINT 1)"); + } + + @Test + public void testSpecialForm() + { + RowExpression specialFormExpression; + + // or and and + specialFormExpression = new SpecialFormExpression(OR, BOOLEAN, createCallExpression(NOT_EQUAL), createCallExpression(IS_DISTINCT_FROM)); + assertEquals(format(specialFormExpression), "((c_bigint) <> (BIGINT 5)) OR ((c_bigint) IS DISTINCT FROM (BIGINT 5))"); + specialFormExpression = new SpecialFormExpression(AND, BOOLEAN, createCallExpression(EQUAL), createCallExpression(GREATER_THAN)); + assertEquals(format(specialFormExpression), "((c_bigint) = (BIGINT 5)) AND ((c_bigint) > (BIGINT 5))"); + + // other + specialFormExpression = new SpecialFormExpression(IS_NULL, BOOLEAN, createCallExpression(ADD)); + assertEquals(format(specialFormExpression), "IS_NULL((c_bigint) + (BIGINT 5))"); + } + + @Test + public void testComplex() + { + RowExpression complexExpression; + + RowExpression expression = createCallExpression(ADD); + complexExpression = call( + SUBTRACT.name(), + functionManager.resolveOperator(SUBTRACT, fromTypes(BIGINT, BIGINT)), + BIGINT, + C_BIGINT, + expression); + assertEquals(format(complexExpression), "(c_bigint) - ((c_bigint) + (BIGINT 5))"); + + RowExpression expression1 = createCallExpression(ADD); + RowExpression expression2 = call( + MULTIPLY.name(), + functionManager.resolveOperator(MULTIPLY, fromTypes(BIGINT, BIGINT)), + BIGINT, + expression1, + C_BIGINT); + RowExpression expression3 = createCallExpression(GREATER_THAN); + complexExpression = new SpecialFormExpression(OR, BOOLEAN, expression2, expression3); + assertEquals(format(complexExpression), "(((c_bigint) + (BIGINT 5)) * (c_bigint)) OR ((c_bigint) > (BIGINT 5))"); + + ArrayType arrayType = (ArrayType) C_BIGINT_ARRAY.getType(); + Type elementType = arrayType.getElementType(); + expression1 = call(SUBSCRIPT.name(), + functionManager.resolveOperator(SUBSCRIPT, fromTypes(arrayType, elementType)), + elementType, + ImmutableList.of(C_BIGINT_ARRAY, constant(5, INTEGER))); + expression2 = call( + NEGATION.name(), + functionManager.resolveOperator(NEGATION, fromTypes(expression1.getType())), + expression1.getType(), + expression1); + expression3 = call( + ADD.name(), + functionManager.resolveOperator(ADD, fromTypes(expression2.getType(), BIGINT)), + BIGINT, + expression2, + constant(5L, BIGINT)); + assertEquals(format(expression3), "(-(c_bigint_array[INTEGER 5])) + (BIGINT 5)"); + } + protected static Object decimal(String decimalString) { return Decimals.parseIncludeLeadingZerosInPrecision(decimalString).getObject(); } + private static CallExpression createCallExpression(OperatorType type) + { + return call( + type.name(), + functionManager.resolveOperator(type, fromTypes(BIGINT, BIGINT)), + BIGINT, + C_BIGINT, + constant(5L, BIGINT)); + } + private static String format(RowExpression expression) { return FORMATTER.formatRowExpression(expression);