diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractOperatorBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractOperatorBenchmark.java index dcfb8c4432697..f3d153e853a4e 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractOperatorBenchmark.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractOperatorBenchmark.java @@ -43,6 +43,7 @@ import com.facebook.presto.spi.memory.MemoryPoolId; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spiller.SpillSpaceTracker; import com.facebook.presto.split.SplitSource; @@ -215,16 +216,18 @@ private static List getNextBatch(SplitSource splitSource) protected final OperatorFactory createHashProjectOperator(int operatorId, PlanNodeId planNodeId, List types) { ImmutableMap.Builder symbolTypes = ImmutableMap.builder(); - ImmutableMap.Builder symbolToInputMapping = ImmutableMap.builder(); + ImmutableList.Builder variables = ImmutableList.builder(); + ImmutableMap.Builder variableToInputMapping = ImmutableMap.builder(); ImmutableList.Builder projections = ImmutableList.builder(); for (int channel = 0; channel < types.size(); channel++) { - Symbol symbol = new Symbol("h" + channel); - symbolTypes.put(symbol, types.get(channel)); - symbolToInputMapping.put(symbol, channel); + VariableReferenceExpression variable = new VariableReferenceExpression("h" + channel, types.get(channel)); + symbolTypes.put(new Symbol(variable.getName()), types.get(channel)); + variables.add(variable); + variableToInputMapping.put(variable, channel); projections.add(new InputPageProjection(channel, types.get(channel))); } - Optional hashExpression = HashGenerationOptimizer.getHashExpression(ImmutableList.copyOf(symbolTypes.build().keySet())); + Optional hashExpression = HashGenerationOptimizer.getHashExpression(variables.build()); verify(hashExpression.isPresent()); Map, Type> expressionTypes = getExpressionTypes( session, @@ -234,7 +237,7 @@ protected final OperatorFactory createHashProjectOperator(int operatorId, PlanNo hashExpression.get(), ImmutableList.of(), WarningCollector.NOOP); - RowExpression translated = translate(hashExpression.get(), expressionTypes, symbolToInputMapping.build(), localQueryRunner.getMetadata().getFunctionManager(), localQueryRunner.getTypeManager(), session, false); + RowExpression translated = translate(hashExpression.get(), expressionTypes, variableToInputMapping.build(), localQueryRunner.getMetadata().getFunctionManager(), localQueryRunner.getTypeManager(), session, false); PageFunctionCompiler functionCompiler = new PageFunctionCompiler(localQueryRunner.getMetadata(), 0); projections.add(functionCompiler.compileProjection(translated, Optional.empty()).get()); diff --git a/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q31.plan.txt b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q31.plan.txt index 83ff150ff6730..2ac018bdac173 100644 --- a/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q31.plan.txt +++ b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q31.plan.txt @@ -18,7 +18,7 @@ remote exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) scan customer_address local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["ca_county_173", NullableValue{type=integer, value=2000}, NullableValue{type=integer, value=2}]) + remote exchange (REPARTITION, HASH, ["ca_county_173", 2, 2000]) final aggregation over (ca_county_173, d_qoy_148, d_year_144) local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["ca_county_173", "d_qoy_148", "d_year_144"]) @@ -47,7 +47,7 @@ remote exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) scan customer_address local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["ca_county_448", NullableValue{type=integer, value=2000}, NullableValue{type=integer, value=2}]) + remote exchange (REPARTITION, HASH, ["ca_county_448", 2, 2000]) final aggregation over (ca_county_448, d_qoy_423, d_year_419) local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["ca_county_448", "d_qoy_423", "d_year_419"]) @@ -62,7 +62,7 @@ remote exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) scan customer_address local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["ca_county", NullableValue{type=integer, value=2000}, NullableValue{type=integer, value=2}]) + remote exchange (REPARTITION, HASH, ["ca_county", 2, 2000]) join (INNER, PARTITIONED): final aggregation over (ca_county, d_qoy, d_year) local exchange (GATHER, SINGLE, []) diff --git a/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestExtractSpatialInnerJoin.java b/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestExtractSpatialInnerJoin.java index e1dad42e5fa29..d98f1fad7e3d4 100644 --- a/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestExtractSpatialInnerJoin.java +++ b/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestExtractSpatialInnerJoin.java @@ -47,7 +47,7 @@ public void testDoesNotFire() p.filter(PlanBuilder.expression("ST_Contains(ST_GeometryFromText('POLYGON ...'), b)"), p.join(INNER, p.values(), - p.values(p.symbol("b"))))) + p.values(p.variable("b"))))) .doesNotFire(); // OR operand @@ -55,8 +55,8 @@ public void testDoesNotFire() .on(p -> p.filter(PlanBuilder.expression("ST_Contains(ST_GeometryFromText(wkt), point) OR name_1 != name_2"), p.join(INNER, - p.values(p.symbol("wkt", VARCHAR), p.symbol("name_1")), - p.values(p.symbol("point", GEOMETRY), p.symbol("name_2"))))) + p.values(p.variable("wkt", VARCHAR), p.variable("name_1")), + p.values(p.variable("point", GEOMETRY), p.variable("name_2"))))) .doesNotFire(); // NOT operator @@ -64,8 +64,8 @@ public void testDoesNotFire() .on(p -> p.filter(PlanBuilder.expression("NOT ST_Contains(ST_GeometryFromText(wkt), point)"), p.join(INNER, - p.values(p.symbol("wkt", VARCHAR), p.symbol("name_1")), - p.values(p.symbol("point", GEOMETRY), p.symbol("name_2"))))) + p.values(p.variable("wkt", VARCHAR), p.variable("name_1")), + p.values(p.variable("point", GEOMETRY), p.variable("name_2"))))) .doesNotFire(); // ST_Distance(...) > r @@ -73,8 +73,8 @@ public void testDoesNotFire() .on(p -> p.filter(PlanBuilder.expression("ST_Distance(a, b) > 5"), p.join(INNER, - p.values(p.symbol("a", GEOMETRY)), - p.values(p.symbol("b", GEOMETRY))))) + p.values(p.variable("a", GEOMETRY)), + p.values(p.variable("b", GEOMETRY))))) .doesNotFire(); // SphericalGeography operand @@ -82,16 +82,16 @@ public void testDoesNotFire() .on(p -> p.filter(PlanBuilder.expression("ST_Distance(a, b) < 5"), p.join(INNER, - p.values(p.symbol("a", SPHERICAL_GEOGRAPHY)), - p.values(p.symbol("b", SPHERICAL_GEOGRAPHY))))) + p.values(p.variable("a", SPHERICAL_GEOGRAPHY)), + p.values(p.variable("b", SPHERICAL_GEOGRAPHY))))) .doesNotFire(); assertRuleApplication() .on(p -> p.filter(PlanBuilder.expression("ST_Contains(polygon, point)"), p.join(INNER, - p.values(p.symbol("polygon", SPHERICAL_GEOGRAPHY)), - p.values(p.symbol("point", SPHERICAL_GEOGRAPHY))))) + p.values(p.variable("polygon", SPHERICAL_GEOGRAPHY)), + p.values(p.variable("point", SPHERICAL_GEOGRAPHY))))) .doesNotFire(); // to_spherical_geography() operand @@ -99,16 +99,16 @@ public void testDoesNotFire() .on(p -> p.filter(PlanBuilder.expression("ST_Distance(to_spherical_geography(ST_GeometryFromText(wkt)), point) < 5"), p.join(INNER, - p.values(p.symbol("wkt", VARCHAR)), - p.values(p.symbol("point", SPHERICAL_GEOGRAPHY))))) + p.values(p.variable("wkt", VARCHAR)), + p.values(p.variable("point", SPHERICAL_GEOGRAPHY))))) .doesNotFire(); assertRuleApplication() .on(p -> p.filter(PlanBuilder.expression("ST_Contains(to_spherical_geography(ST_GeometryFromText(wkt)), point)"), p.join(INNER, - p.values(p.symbol("wkt", VARCHAR)), - p.values(p.symbol("point", SPHERICAL_GEOGRAPHY))))) + p.values(p.variable("wkt", VARCHAR)), + p.values(p.variable("point", SPHERICAL_GEOGRAPHY))))) .doesNotFire(); } @@ -187,8 +187,8 @@ private void testSimpleDistanceQuery(String filter, String newFilter) .on(p -> p.filter(PlanBuilder.expression(filter), p.join(INNER, - p.values(p.symbol("a", GEOMETRY), p.symbol("name_a")), - p.values(p.symbol("b", GEOMETRY), p.symbol("name_b"), p.symbol("r"))))) + p.values(p.variable("a", GEOMETRY), p.variable("name_a")), + p.values(p.variable("b", GEOMETRY), p.variable("name_b"), p.variable("r"))))) .matches( spatialJoin(newFilter, values(ImmutableMap.of("a", 0, "name_a", 1)), @@ -201,8 +201,8 @@ private void testRadiusExpressionInDistanceQuery(String filter, String newFilter .on(p -> p.filter(PlanBuilder.expression(filter), p.join(INNER, - p.values(p.symbol("a", GEOMETRY), p.symbol("name_a")), - p.values(p.symbol("b", GEOMETRY), p.symbol("name_b"), p.symbol("r"))))) + p.values(p.variable("a", GEOMETRY), p.variable("name_a")), + p.values(p.variable("b", GEOMETRY), p.variable("name_b"), p.variable("r"))))) .matches( spatialJoin(newFilter, values(ImmutableMap.of("a", 0, "name_a", 1)), @@ -216,8 +216,8 @@ private void testPointExpressionsInDistanceQuery(String filter, String newFilter .on(p -> p.filter(PlanBuilder.expression(filter), p.join(INNER, - p.values(p.symbol("lat_a"), p.symbol("lng_a"), p.symbol("name_a")), - p.values(p.symbol("lat_b"), p.symbol("lng_b"), p.symbol("name_b"))))) + p.values(p.variable("lat_a"), p.variable("lng_a"), p.variable("name_a")), + p.values(p.variable("lat_b"), p.variable("lng_b"), p.variable("name_b"))))) .matches( spatialJoin(newFilter, project(ImmutableMap.of("point_a", expression("ST_Point(lng_a, lat_a)")), @@ -232,8 +232,8 @@ private void testPointAndRadiusExpressionsInDistanceQuery(String filter, String .on(p -> p.filter(PlanBuilder.expression(filter), p.join(INNER, - p.values(p.symbol("lat_a"), p.symbol("lng_a"), p.symbol("name_a")), - p.values(p.symbol("lat_b"), p.symbol("lng_b"), p.symbol("name_b"))))) + p.values(p.variable("lat_a"), p.variable("lng_a"), p.variable("name_a")), + p.values(p.variable("lat_b"), p.variable("lng_b"), p.variable("name_b"))))) .matches( spatialJoin(newFilter, project(ImmutableMap.of("point_a", expression("ST_Point(lng_a, lat_a)")), @@ -251,8 +251,8 @@ public void testConvertToSpatialJoin() .on(p -> p.filter(PlanBuilder.expression("ST_Contains(a, b)"), p.join(INNER, - p.values(p.symbol("a")), - p.values(p.symbol("b"))))) + p.values(p.variable("a")), + p.values(p.variable("b"))))) .matches( spatialJoin("ST_Contains(a, b)", values(ImmutableMap.of("a", 0)), @@ -263,8 +263,8 @@ public void testConvertToSpatialJoin() .on(p -> p.filter(PlanBuilder.expression("name_1 != name_2 AND ST_Contains(a, b)"), p.join(INNER, - p.values(p.symbol("a"), p.symbol("name_1")), - p.values(p.symbol("b"), p.symbol("name_2"))))) + p.values(p.variable("a"), p.variable("name_1")), + p.values(p.variable("b"), p.variable("name_2"))))) .matches( spatialJoin("name_1 != name_2 AND ST_Contains(a, b)", values(ImmutableMap.of("a", 0, "name_1", 1)), @@ -275,8 +275,8 @@ public void testConvertToSpatialJoin() .on(p -> p.filter(PlanBuilder.expression("ST_Contains(a1, b1) AND ST_Contains(a2, b2)"), p.join(INNER, - p.values(p.symbol("a1"), p.symbol("a2")), - p.values(p.symbol("b1"), p.symbol("b2"))))) + p.values(p.variable("a1"), p.variable("a2")), + p.values(p.variable("b1"), p.variable("b2"))))) .matches( spatialJoin("ST_Contains(a1, b1) AND ST_Contains(a2, b2)", values(ImmutableMap.of("a1", 0, "a2", 1)), @@ -290,8 +290,8 @@ public void testPushDownFirstArgument() .on(p -> p.filter(PlanBuilder.expression("ST_Contains(ST_GeometryFromText(wkt), point)"), p.join(INNER, - p.values(p.symbol("wkt", VARCHAR)), - p.values(p.symbol("point", GEOMETRY))))) + p.values(p.variable("wkt", VARCHAR)), + p.values(p.variable("point", GEOMETRY))))) .matches( spatialJoin("ST_Contains(st_geometryfromtext, point)", project(ImmutableMap.of("st_geometryfromtext", expression("ST_GeometryFromText(wkt)")), values(ImmutableMap.of("wkt", 0))), @@ -301,7 +301,7 @@ public void testPushDownFirstArgument() .on(p -> p.filter(PlanBuilder.expression("ST_Contains(ST_GeometryFromText(wkt), ST_Point(0, 0))"), p.join(INNER, - p.values(p.symbol("wkt", VARCHAR)), + p.values(p.variable("wkt", VARCHAR)), p.values()))) .doesNotFire(); } @@ -313,8 +313,8 @@ public void testPushDownSecondArgument() .on(p -> p.filter(PlanBuilder.expression("ST_Contains(polygon, ST_Point(lng, lat))"), p.join(INNER, - p.values(p.symbol("polygon", GEOMETRY)), - p.values(p.symbol("lat"), p.symbol("lng"))))) + p.values(p.variable("polygon", GEOMETRY)), + p.values(p.variable("lat"), p.variable("lng"))))) .matches( spatialJoin("ST_Contains(polygon, st_point)", values(ImmutableMap.of("polygon", 0)), @@ -325,7 +325,7 @@ public void testPushDownSecondArgument() p.filter(PlanBuilder.expression("ST_Contains(ST_GeometryFromText('POLYGON ...'), ST_Point(lng, lat))"), p.join(INNER, p.values(), - p.values(p.symbol("lat"), p.symbol("lng"))))) + p.values(p.variable("lat"), p.variable("lng"))))) .doesNotFire(); } @@ -336,8 +336,8 @@ public void testPushDownBothArguments() .on(p -> p.filter(PlanBuilder.expression("ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))"), p.join(INNER, - p.values(p.symbol("wkt", VARCHAR)), - p.values(p.symbol("lat"), p.symbol("lng"))))) + p.values(p.variable("wkt", VARCHAR)), + p.values(p.variable("lat"), p.variable("lng"))))) .matches( spatialJoin("ST_Contains(st_geometryfromtext, st_point)", project(ImmutableMap.of("st_geometryfromtext", expression("ST_GeometryFromText(wkt)")), values(ImmutableMap.of("wkt", 0))), @@ -351,8 +351,8 @@ public void testPushDownOppositeOrder() .on(p -> p.filter(PlanBuilder.expression("ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))"), p.join(INNER, - p.values(p.symbol("lat"), p.symbol("lng")), - p.values(p.symbol("wkt", VARCHAR))))) + p.values(p.variable("lat"), p.variable("lng")), + p.values(p.variable("wkt", VARCHAR))))) .matches( spatialJoin("ST_Contains(st_geometryfromtext, st_point)", project(ImmutableMap.of("st_point", expression("ST_Point(lng, lat)")), values(ImmutableMap.of("lat", 0, "lng", 1))), @@ -366,8 +366,8 @@ public void testPushDownAnd() .on(p -> p.filter(PlanBuilder.expression("name_1 != name_2 AND ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))"), p.join(INNER, - p.values(p.symbol("wkt", VARCHAR), p.symbol("name_1")), - p.values(p.symbol("lat"), p.symbol("lng"), p.symbol("name_2"))))) + p.values(p.variable("wkt", VARCHAR), p.variable("name_1")), + p.values(p.variable("lat"), p.variable("lng"), p.variable("name_2"))))) .matches( spatialJoin("name_1 != name_2 AND ST_Contains(st_geometryfromtext, st_point)", project(ImmutableMap.of("st_geometryfromtext", expression("ST_GeometryFromText(wkt)")), values(ImmutableMap.of("wkt", 0, "name_1", 1))), @@ -378,8 +378,8 @@ public void testPushDownAnd() .on(p -> p.filter(PlanBuilder.expression("ST_Contains(ST_GeometryFromText(wkt1), geometry1) AND ST_Contains(ST_GeometryFromText(wkt2), geometry2)"), p.join(INNER, - p.values(p.symbol("wkt1", VARCHAR), p.symbol("wkt2", VARCHAR)), - p.values(p.symbol("geometry1"), p.symbol("geometry2"))))) + p.values(p.variable("wkt1", VARCHAR), p.variable("wkt2", VARCHAR)), + p.values(p.variable("geometry1"), p.variable("geometry2"))))) .matches( spatialJoin("ST_Contains(st_geometryfromtext, geometry1) AND ST_Contains(ST_GeometryFromText(wkt2), geometry2)", project(ImmutableMap.of("st_geometryfromtext", expression("ST_GeometryFromText(wkt1)")), values(ImmutableMap.of("wkt1", 0, "wkt2", 1))), diff --git a/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestExtractSpatialLeftJoin.java b/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestExtractSpatialLeftJoin.java index 98f43ad4719b5..5b526ff9aac43 100644 --- a/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestExtractSpatialLeftJoin.java +++ b/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestExtractSpatialLeftJoin.java @@ -46,7 +46,7 @@ public void testDoesNotFire() .on(p -> p.join(LEFT, p.values(), - p.values(p.symbol("b")), + p.values(p.variable("b")), expression("ST_Contains(ST_GeometryFromText('POLYGON ...'), b)"))) .doesNotFire(); @@ -54,8 +54,8 @@ public void testDoesNotFire() assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("wkt", VARCHAR), p.symbol("name_1")), - p.values(p.symbol("point", GEOMETRY), p.symbol("name_2")), + p.values(p.variable("wkt", VARCHAR), p.variable("name_1")), + p.values(p.variable("point", GEOMETRY), p.variable("name_2")), expression("ST_Contains(ST_GeometryFromText(wkt), point) OR name_1 != name_2"))) .doesNotFire(); @@ -63,8 +63,8 @@ public void testDoesNotFire() assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("wkt", VARCHAR), p.symbol("name_1")), - p.values(p.symbol("point", GEOMETRY), p.symbol("name_2")), + p.values(p.variable("wkt", VARCHAR), p.variable("name_1")), + p.values(p.variable("point", GEOMETRY), p.variable("name_2")), expression("NOT ST_Contains(ST_GeometryFromText(wkt), point)"))) .doesNotFire(); @@ -72,8 +72,8 @@ public void testDoesNotFire() assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("a", GEOMETRY)), - p.values(p.symbol("b", GEOMETRY)), + p.values(p.variable("a", GEOMETRY)), + p.values(p.variable("b", GEOMETRY)), expression("ST_Distance(a, b) > 5"))) .doesNotFire(); @@ -81,16 +81,16 @@ public void testDoesNotFire() assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("a", SPHERICAL_GEOGRAPHY)), - p.values(p.symbol("b", SPHERICAL_GEOGRAPHY)), + p.values(p.variable("a", SPHERICAL_GEOGRAPHY)), + p.values(p.variable("b", SPHERICAL_GEOGRAPHY)), expression("ST_Distance(a, b) < 5"))) .doesNotFire(); assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("polygon", SPHERICAL_GEOGRAPHY)), - p.values(p.symbol("point", SPHERICAL_GEOGRAPHY)), + p.values(p.variable("polygon", SPHERICAL_GEOGRAPHY)), + p.values(p.variable("point", SPHERICAL_GEOGRAPHY)), expression("ST_Contains(polygon, point)"))) .doesNotFire(); @@ -98,16 +98,16 @@ public void testDoesNotFire() assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("wkt", VARCHAR)), - p.values(p.symbol("point", SPHERICAL_GEOGRAPHY)), + p.values(p.variable("wkt", VARCHAR)), + p.values(p.variable("point", SPHERICAL_GEOGRAPHY)), expression("ST_Distance(to_spherical_geography(ST_GeometryFromText(wkt)), point) < 5"))) .doesNotFire(); assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("wkt", VARCHAR)), - p.values(p.symbol("point", SPHERICAL_GEOGRAPHY)), + p.values(p.variable("wkt", VARCHAR)), + p.values(p.variable("point", SPHERICAL_GEOGRAPHY)), expression("ST_Contains(to_spherical_geography(ST_GeometryFromText(wkt)), point)"))) .doesNotFire(); } @@ -119,8 +119,8 @@ public void testConvertToSpatialJoin() assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("a")), - p.values(p.symbol("b")), + p.values(p.variable("a")), + p.values(p.variable("b")), p.expression("ST_Contains(a, b)"))) .matches( spatialLeftJoin("ST_Contains(a, b)", @@ -131,8 +131,8 @@ public void testConvertToSpatialJoin() assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("a"), p.symbol("name_1")), - p.values(p.symbol("b"), p.symbol("name_2")), + p.values(p.variable("a"), p.variable("name_1")), + p.values(p.variable("b"), p.variable("name_2")), p.expression("name_1 != name_2 AND ST_Contains(a, b)"))) .matches( spatialLeftJoin("name_1 != name_2 AND ST_Contains(a, b)", @@ -143,8 +143,8 @@ public void testConvertToSpatialJoin() assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("a1"), p.symbol("a2")), - p.values(p.symbol("b1"), p.symbol("b2")), + p.values(p.variable("a1"), p.variable("a2")), + p.values(p.variable("b1"), p.variable("b2")), p.expression("ST_Contains(a1, b1) AND ST_Contains(a2, b2)"))) .matches( spatialLeftJoin("ST_Contains(a1, b1) AND ST_Contains(a2, b2)", @@ -158,8 +158,8 @@ public void testPushDownFirstArgument() assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("wkt", VARCHAR)), - p.values(p.symbol("point", GEOMETRY)), + p.values(p.variable("wkt", VARCHAR)), + p.values(p.variable("point", GEOMETRY)), expression("ST_Contains(ST_GeometryFromText(wkt), point)"))) .matches( spatialLeftJoin("ST_Contains(st_geometryfromtext, point)", @@ -169,7 +169,7 @@ public void testPushDownFirstArgument() assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("wkt", VARCHAR)), + p.values(p.variable("wkt", VARCHAR)), p.values(), expression("ST_Contains(ST_GeometryFromText(wkt), ST_Point(0, 0))"))) .doesNotFire(); @@ -181,8 +181,8 @@ public void testPushDownSecondArgument() assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("polygon", GEOMETRY)), - p.values(p.symbol("lat"), p.symbol("lng")), + p.values(p.variable("polygon", GEOMETRY)), + p.values(p.variable("lat"), p.variable("lng")), expression("ST_Contains(polygon, ST_Point(lng, lat))"))) .matches( spatialLeftJoin("ST_Contains(polygon, st_point)", @@ -193,7 +193,7 @@ public void testPushDownSecondArgument() .on(p -> p.join(LEFT, p.values(), - p.values(p.symbol("lat"), p.symbol("lng")), + p.values(p.variable("lat"), p.variable("lng")), expression("ST_Contains(ST_GeometryFromText('POLYGON ...'), ST_Point(lng, lat))"))) .doesNotFire(); } @@ -204,8 +204,8 @@ public void testPushDownBothArguments() assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("wkt", VARCHAR)), - p.values(p.symbol("lat"), p.symbol("lng")), + p.values(p.variable("wkt", VARCHAR)), + p.values(p.variable("lat"), p.variable("lng")), expression("ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))"))) .matches( spatialLeftJoin("ST_Contains(st_geometryfromtext, st_point)", @@ -219,8 +219,8 @@ public void testPushDownOppositeOrder() assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("lat"), p.symbol("lng")), - p.values(p.symbol("wkt", VARCHAR)), + p.values(p.variable("lat"), p.variable("lng")), + p.values(p.variable("wkt", VARCHAR)), expression("ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))"))) .matches( spatialLeftJoin("ST_Contains(st_geometryfromtext, st_point)", @@ -234,8 +234,8 @@ public void testPushDownAnd() assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("wkt", VARCHAR), p.symbol("name_1")), - p.values(p.symbol("lat"), p.symbol("lng"), p.symbol("name_2")), + p.values(p.variable("wkt", VARCHAR), p.variable("name_1")), + p.values(p.variable("lat"), p.variable("lng"), p.variable("name_2")), expression("name_1 != name_2 AND ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))"))) .matches( spatialLeftJoin("name_1 != name_2 AND ST_Contains(st_geometryfromtext, st_point)", @@ -246,8 +246,8 @@ public void testPushDownAnd() assertRuleApplication() .on(p -> p.join(LEFT, - p.values(p.symbol("wkt1", VARCHAR), p.symbol("wkt2", VARCHAR)), - p.values(p.symbol("geometry1"), p.symbol("geometry2")), + p.values(p.variable("wkt1", VARCHAR), p.variable("wkt2", VARCHAR)), + p.values(p.variable("geometry1"), p.variable("geometry2")), expression("ST_Contains(ST_GeometryFromText(wkt1), geometry1) AND ST_Contains(ST_GeometryFromText(wkt2), geometry2)"))) .matches( spatialLeftJoin("ST_Contains(st_geometryfromtext, geometry1) AND ST_Contains(ST_GeometryFromText(wkt2), geometry2)", diff --git a/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestRewriteSpatialPartitioningAggregation.java b/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestRewriteSpatialPartitioningAggregation.java index 189d4ee7e7ca6..662f11dbe1d00 100644 --- a/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestRewriteSpatialPartitioningAggregation.java +++ b/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestRewriteSpatialPartitioningAggregation.java @@ -44,8 +44,8 @@ public void testDoesNotFire() .on(p -> p.aggregation(a -> a.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("sp"), PlanBuilder.expression("spatial_partitioning(geometry, 10)"), ImmutableList.of(GEOMETRY)) - .source(p.values(p.symbol("geometry"))))) + .addAggregation(p.variable(p.symbol("sp")), PlanBuilder.expression("spatial_partitioning(geometry, 10)"), ImmutableList.of(GEOMETRY)) + .source(p.values(p.variable("geometry"))))) .doesNotFire(); } @@ -56,8 +56,8 @@ public void test() .on(p -> p.aggregation(a -> a.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("sp"), PlanBuilder.expression("spatial_partitioning(geometry)"), ImmutableList.of(GEOMETRY)) - .source(p.values(p.symbol("geometry"))))) + .addAggregation(p.variable(p.symbol("sp")), PlanBuilder.expression("spatial_partitioning(geometry)"), ImmutableList.of(GEOMETRY)) + .source(p.values(p.variable("geometry"))))) .matches( aggregation( ImmutableMap.of("sp", functionCall("spatial_partitioning", ImmutableList.of("envelope", "partition_count"))), @@ -70,8 +70,8 @@ public void test() .on(p -> p.aggregation(a -> a.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("sp"), PlanBuilder.expression("spatial_partitioning(ST_Envelope(geometry))"), ImmutableList.of(GEOMETRY)) - .source(p.values(p.symbol("geometry"))))) + .addAggregation(p.variable(p.symbol("sp")), PlanBuilder.expression("spatial_partitioning(ST_Envelope(geometry))"), ImmutableList.of(GEOMETRY)) + .source(p.values(p.variable("geometry"))))) .matches( aggregation( ImmutableMap.of("sp", functionCall("spatial_partitioning", ImmutableList.of("envelope", "partition_count"))), diff --git a/presto-main/src/main/java/com/facebook/presto/cost/AggregationStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/AggregationStatsRule.java index 6f69e0bd178c3..dbd7016c64c8c 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/AggregationStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/AggregationStatsRule.java @@ -15,7 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -63,12 +63,12 @@ protected Optional doCalculate(AggregationNode node, Stat node.getAggregations())); } - public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, Collection groupBySymbols, Map aggregations) + public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, Collection groupByVariables, Map aggregations) { PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder(); - for (Symbol groupBySymbol : groupBySymbols) { - SymbolStatsEstimate symbolStatistics = sourceStats.getSymbolStatistics(groupBySymbol); - result.addSymbolStatistics(groupBySymbol, symbolStatistics.mapNullsFraction(nullsFraction -> { + for (VariableReferenceExpression groupByVariable : groupByVariables) { + VariableStatsEstimate symbolStatistics = sourceStats.getVariableStatistics(groupByVariable); + result.addVariableStatistics(groupByVariable, symbolStatistics.mapNullsFraction(nullsFraction -> { if (nullsFraction == 0.0) { return 0.0; } @@ -77,26 +77,26 @@ public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, C } double rowsCount = 1; - for (Symbol groupBySymbol : groupBySymbols) { - SymbolStatsEstimate symbolStatistics = sourceStats.getSymbolStatistics(groupBySymbol); + for (VariableReferenceExpression groupByVariable : groupByVariables) { + VariableStatsEstimate symbolStatistics = sourceStats.getVariableStatistics(groupByVariable); int nullRow = (symbolStatistics.getNullsFraction() == 0.0) ? 0 : 1; rowsCount *= symbolStatistics.getDistinctValuesCount() + nullRow; } result.setOutputRowCount(min(rowsCount, sourceStats.getOutputRowCount())); - for (Map.Entry aggregationEntry : aggregations.entrySet()) { - result.addSymbolStatistics(aggregationEntry.getKey(), estimateAggregationStats(aggregationEntry.getValue(), sourceStats)); + for (Map.Entry aggregationEntry : aggregations.entrySet()) { + result.addVariableStatistics(aggregationEntry.getKey(), estimateAggregationStats(aggregationEntry.getValue(), sourceStats)); } return result.build(); } - private static SymbolStatsEstimate estimateAggregationStats(Aggregation aggregation, PlanNodeStatsEstimate sourceStats) + private static VariableStatsEstimate estimateAggregationStats(Aggregation aggregation, PlanNodeStatsEstimate sourceStats) { requireNonNull(aggregation, "aggregation is null"); requireNonNull(sourceStats, "sourceStats is null"); // TODO implement simple aggregations like: min, max, count, sum - return SymbolStatsEstimate.unknown(); + return VariableStatsEstimate.unknown(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/AssignUniqueIdStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/AssignUniqueIdStatsRule.java index 153922f37a9ab..225dd2b5ef595 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/AssignUniqueIdStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/AssignUniqueIdStatsRule.java @@ -40,7 +40,7 @@ public Optional calculate(AssignUniqueId assignUniqueId, { PlanNodeStatsEstimate sourceStats = statsProvider.getStats(assignUniqueId.getSource()); return Optional.of(PlanNodeStatsEstimate.buildFrom(sourceStats) - .addSymbolStatistics(assignUniqueId.getIdColumn(), SymbolStatsEstimate.builder() + .addVariableStatistics(assignUniqueId.getIdVariable(), VariableStatsEstimate.builder() .setDistinctValuesCount(sourceStats.getOutputRowCount()) .setNullsFraction(0.0) .setAverageRowSize(BIGINT.getFixedSize()) diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ComparisonStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/ComparisonStatsCalculator.java index 5625074e25695..f695c5abb8dce 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ComparisonStatsCalculator.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/ComparisonStatsCalculator.java @@ -13,13 +13,13 @@ */ package com.facebook.presto.cost; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.tree.ComparisonExpression; import java.util.Optional; import java.util.OptionalDouble; -import static com.facebook.presto.cost.SymbolStatsEstimate.buildFrom; +import static com.facebook.presto.cost.VariableStatsEstimate.buildFrom; import static com.facebook.presto.util.MoreMath.firstNonNaN; import static com.facebook.presto.util.MoreMath.max; import static com.facebook.presto.util.MoreMath.min; @@ -35,22 +35,22 @@ private ComparisonStatsCalculator() {} public static PlanNodeStatsEstimate estimateExpressionToLiteralComparison( PlanNodeStatsEstimate inputStatistics, - SymbolStatsEstimate expressionStatistics, - Optional expressionSymbol, + VariableStatsEstimate expressionStatistics, + Optional expressionVariable, OptionalDouble literalValue, ComparisonExpression.Operator operator) { switch (operator) { case EQUAL: - return estimateExpressionEqualToLiteral(inputStatistics, expressionStatistics, expressionSymbol, literalValue); + return estimateExpressionEqualToLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue); case NOT_EQUAL: - return estimateExpressionNotEqualToLiteral(inputStatistics, expressionStatistics, expressionSymbol, literalValue); + return estimateExpressionNotEqualToLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue); case LESS_THAN: case LESS_THAN_OR_EQUAL: - return estimateExpressionLessThanLiteral(inputStatistics, expressionStatistics, expressionSymbol, literalValue); + return estimateExpressionLessThanLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue); case GREATER_THAN: case GREATER_THAN_OR_EQUAL: - return estimateExpressionGreaterThanLiteral(inputStatistics, expressionStatistics, expressionSymbol, literalValue); + return estimateExpressionGreaterThanLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue); case IS_DISTINCT_FROM: return PlanNodeStatsEstimate.unknown(); default: @@ -60,8 +60,8 @@ public static PlanNodeStatsEstimate estimateExpressionToLiteralComparison( private static PlanNodeStatsEstimate estimateExpressionEqualToLiteral( PlanNodeStatsEstimate inputStatistics, - SymbolStatsEstimate expressionStatistics, - Optional expressionSymbol, + VariableStatsEstimate expressionStatistics, + Optional expressionVariable, OptionalDouble literalValue) { StatisticRange filterRange; @@ -71,13 +71,13 @@ private static PlanNodeStatsEstimate estimateExpressionEqualToLiteral( else { filterRange = new StatisticRange(NEGATIVE_INFINITY, POSITIVE_INFINITY, 1); } - return estimateFilterRange(inputStatistics, expressionStatistics, expressionSymbol, filterRange); + return estimateFilterRange(inputStatistics, expressionStatistics, expressionVariable, filterRange); } private static PlanNodeStatsEstimate estimateExpressionNotEqualToLiteral( PlanNodeStatsEstimate inputStatistics, - SymbolStatsEstimate expressionStatistics, - Optional expressionSymbol, + VariableStatsEstimate expressionStatistics, + Optional expressionVariable, OptionalDouble literalValue) { StatisticRange expressionRange = StatisticRange.from(expressionStatistics); @@ -94,40 +94,40 @@ private static PlanNodeStatsEstimate estimateExpressionNotEqualToLiteral( PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics); estimate.setOutputRowCount(filterFactor * (1 - expressionStatistics.getNullsFraction()) * inputStatistics.getOutputRowCount()); - if (expressionSymbol.isPresent()) { - SymbolStatsEstimate symbolNewEstimate = buildFrom(expressionStatistics) + if (expressionVariable.isPresent()) { + VariableStatsEstimate symbolNewEstimate = buildFrom(expressionStatistics) .setNullsFraction(0.0) .setDistinctValuesCount(max(expressionStatistics.getDistinctValuesCount() - 1, 0)) .build(); - estimate = estimate.addSymbolStatistics(expressionSymbol.get(), symbolNewEstimate); + estimate = estimate.addVariableStatistics(expressionVariable.get(), symbolNewEstimate); } return estimate.build(); } private static PlanNodeStatsEstimate estimateExpressionLessThanLiteral( PlanNodeStatsEstimate inputStatistics, - SymbolStatsEstimate expressionStatistics, - Optional expressionSymbol, + VariableStatsEstimate expressionStatistics, + Optional expressionVariable, OptionalDouble literalValue) { StatisticRange filterRange = new StatisticRange(NEGATIVE_INFINITY, literalValue.orElse(POSITIVE_INFINITY), NaN); - return estimateFilterRange(inputStatistics, expressionStatistics, expressionSymbol, filterRange); + return estimateFilterRange(inputStatistics, expressionStatistics, expressionVariable, filterRange); } private static PlanNodeStatsEstimate estimateExpressionGreaterThanLiteral( PlanNodeStatsEstimate inputStatistics, - SymbolStatsEstimate expressionStatistics, - Optional expressionSymbol, + VariableStatsEstimate expressionStatistics, + Optional expressionVariable, OptionalDouble literalValue) { StatisticRange filterRange = new StatisticRange(literalValue.orElse(NEGATIVE_INFINITY), POSITIVE_INFINITY, NaN); - return estimateFilterRange(inputStatistics, expressionStatistics, expressionSymbol, filterRange); + return estimateFilterRange(inputStatistics, expressionStatistics, expressionVariable, filterRange); } private static PlanNodeStatsEstimate estimateFilterRange( PlanNodeStatsEstimate inputStatistics, - SymbolStatsEstimate expressionStatistics, - Optional expressionSymbol, + VariableStatsEstimate expressionStatistics, + Optional expressionVariable, StatisticRange filterRange) { StatisticRange expressionRange = StatisticRange.from(expressionStatistics); @@ -136,31 +136,31 @@ private static PlanNodeStatsEstimate estimateFilterRange( double filterFactor = expressionRange.overlapPercentWith(intersectRange); PlanNodeStatsEstimate estimate = inputStatistics.mapOutputRowCount(rowCount -> filterFactor * (1 - expressionStatistics.getNullsFraction()) * rowCount); - if (expressionSymbol.isPresent()) { - SymbolStatsEstimate symbolNewEstimate = - SymbolStatsEstimate.builder() + if (expressionVariable.isPresent()) { + VariableStatsEstimate symbolNewEstimate = + VariableStatsEstimate.builder() .setAverageRowSize(expressionStatistics.getAverageRowSize()) .setStatisticsRange(intersectRange) .setNullsFraction(0.0) .build(); - estimate = estimate.mapSymbolColumnStatistics(expressionSymbol.get(), oldStats -> symbolNewEstimate); + estimate = estimate.mapVariableColumnStatistics(expressionVariable.get(), oldStats -> symbolNewEstimate); } return estimate; } public static PlanNodeStatsEstimate estimateExpressionToExpressionComparison( PlanNodeStatsEstimate inputStatistics, - SymbolStatsEstimate leftExpressionStatistics, - Optional leftExpressionSymbol, - SymbolStatsEstimate rightExpressionStatistics, - Optional rightExpressionSymbol, + VariableStatsEstimate leftExpressionStatistics, + Optional leftExpressionVariable, + VariableStatsEstimate rightExpressionStatistics, + Optional rightExpressionVariable, ComparisonExpression.Operator operator) { switch (operator) { case EQUAL: - return estimateExpressionEqualToExpression(inputStatistics, leftExpressionStatistics, leftExpressionSymbol, rightExpressionStatistics, rightExpressionSymbol); + return estimateExpressionEqualToExpression(inputStatistics, leftExpressionStatistics, leftExpressionVariable, rightExpressionStatistics, rightExpressionVariable); case NOT_EQUAL: - return estimateExpressionNotEqualToExpression(inputStatistics, leftExpressionStatistics, leftExpressionSymbol, rightExpressionStatistics, rightExpressionSymbol); + return estimateExpressionNotEqualToExpression(inputStatistics, leftExpressionStatistics, leftExpressionVariable, rightExpressionStatistics, rightExpressionVariable); case LESS_THAN: case LESS_THAN_OR_EQUAL: case GREATER_THAN: @@ -174,10 +174,10 @@ public static PlanNodeStatsEstimate estimateExpressionToExpressionComparison( private static PlanNodeStatsEstimate estimateExpressionEqualToExpression( PlanNodeStatsEstimate inputStatistics, - SymbolStatsEstimate leftExpressionStatistics, - Optional leftExpressionSymbol, - SymbolStatsEstimate rightExpressionStatistics, - Optional rightExpressionSymbol) + VariableStatsEstimate leftExpressionStatistics, + Optional leftExpressionVariable, + VariableStatsEstimate rightExpressionStatistics, + Optional rightExpressionVariable) { if (isNaN(leftExpressionStatistics.getDistinctValuesCount()) || isNaN(rightExpressionStatistics.getDistinctValuesCount())) { return PlanNodeStatsEstimate.unknown(); @@ -197,36 +197,36 @@ private static PlanNodeStatsEstimate estimateExpressionEqualToExpression( PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics) .setOutputRowCount(inputStatistics.getOutputRowCount() * nullsFilterFactor * filterFactor); - SymbolStatsEstimate equalityStats = SymbolStatsEstimate.builder() + VariableStatsEstimate equalityStats = VariableStatsEstimate.builder() .setAverageRowSize(averageExcludingNaNs(leftExpressionStatistics.getAverageRowSize(), rightExpressionStatistics.getAverageRowSize())) .setNullsFraction(0) .setStatisticsRange(intersect) .setDistinctValuesCount(retainedNdv) .build(); - leftExpressionSymbol.ifPresent(symbol -> estimate.addSymbolStatistics(symbol, equalityStats)); - rightExpressionSymbol.ifPresent(symbol -> estimate.addSymbolStatistics(symbol, equalityStats)); + leftExpressionVariable.ifPresent(variable -> estimate.addVariableStatistics(variable, equalityStats)); + rightExpressionVariable.ifPresent(variable -> estimate.addVariableStatistics(variable, equalityStats)); return estimate.build(); } private static PlanNodeStatsEstimate estimateExpressionNotEqualToExpression( PlanNodeStatsEstimate inputStatistics, - SymbolStatsEstimate leftExpressionStatistics, - Optional leftExpressionSymbol, - SymbolStatsEstimate rightExpressionStatistics, - Optional rightExpressionSymbol) + VariableStatsEstimate leftExpressionStatistics, + Optional leftExpressionVariable, + VariableStatsEstimate rightExpressionStatistics, + Optional rightExpressionVariable) { double nullsFilterFactor = (1 - leftExpressionStatistics.getNullsFraction()) * (1 - rightExpressionStatistics.getNullsFraction()); PlanNodeStatsEstimate inputNullsFiltered = inputStatistics.mapOutputRowCount(size -> size * nullsFilterFactor); - SymbolStatsEstimate leftNullsFiltered = leftExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0); - SymbolStatsEstimate rightNullsFiltered = rightExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0); + VariableStatsEstimate leftNullsFiltered = leftExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0); + VariableStatsEstimate rightNullsFiltered = rightExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0); PlanNodeStatsEstimate equalityStats = estimateExpressionEqualToExpression( inputNullsFiltered, leftNullsFiltered, - leftExpressionSymbol, + leftExpressionVariable, rightNullsFiltered, - rightExpressionSymbol); + rightExpressionVariable); if (equalityStats.isOutputRowCountUnknown()) { return PlanNodeStatsEstimate.unknown(); } @@ -237,8 +237,8 @@ private static PlanNodeStatsEstimate estimateExpressionNotEqualToExpression( equalityFilterFactor = 0.0; } result.setOutputRowCount(inputNullsFiltered.getOutputRowCount() * (1 - equalityFilterFactor)); - leftExpressionSymbol.ifPresent(symbol -> result.addSymbolStatistics(symbol, leftNullsFiltered)); - rightExpressionSymbol.ifPresent(symbol -> result.addSymbolStatistics(symbol, rightNullsFiltered)); + leftExpressionVariable.ifPresent(symbol -> result.addVariableStatistics(symbol, leftNullsFiltered)); + rightExpressionVariable.ifPresent(symbol -> result.addVariableStatistics(symbol, rightNullsFiltered)); return result.build(); } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java index bed1fbe6cdb3a..44813669d5e1e 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java @@ -15,7 +15,7 @@ package com.facebook.presto.cost; import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -113,25 +113,25 @@ public PlanCostEstimate visitGroupReference(GroupReference node, Void context) @Override public PlanCostEstimate visitAssignUniqueId(AssignUniqueId node, Void context) { - LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(ImmutableList.of(node.getIdColumn()), types)); + LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(ImmutableList.of(node.getIdVariable()))); return costForStreaming(node, localCost); } @Override public PlanCostEstimate visitRowNumber(RowNumberNode node, Void context) { - List symbols = node.getOutputSymbols(); + List variables = node.getOutputVariables(); // when maxRowCountPerPartition is set, the RowNumberOperator // copies values for all the columns into a page builder if (!node.getMaxRowCountPerPartition().isPresent()) { - symbols = ImmutableList.builder() + variables = ImmutableList.builder() .addAll(node.getPartitionBy()) - .add(node.getRowNumberSymbol()) + .add(node.getRowNumberVariable()) .build(); } PlanNodeStatsEstimate stats = getStats(node); - double cpuCost = stats.getOutputSizeInBytes(symbols, types); - double memoryCost = node.getPartitionBy().isEmpty() ? 0 : stats.getOutputSizeInBytes(node.getSource().getOutputSymbols(), types); + double cpuCost = stats.getOutputSizeInBytes(variables); + double memoryCost = node.getPartitionBy().isEmpty() ? 0 : stats.getOutputSizeInBytes(node.getSource().getOutputVariables()); LocalCostEstimate localCost = LocalCostEstimate.of(cpuCost, memoryCost, 0); return costForStreaming(node, localCost); } @@ -146,21 +146,21 @@ public PlanCostEstimate visitOutput(OutputNode node, Void context) public PlanCostEstimate visitTableScan(TableScanNode node, Void context) { // TODO: add network cost, based on input size in bytes? Or let connector provide this cost? - LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types)); + LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(node.getOutputVariables())); return costForSource(node, localCost); } @Override public PlanCostEstimate visitFilter(FilterNode node, Void context) { - LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node.getSource()).getOutputSizeInBytes(node.getOutputSymbols(), types)); + LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node.getSource()).getOutputSizeInBytes(node.getOutputVariables())); return costForStreaming(node, localCost); } @Override public PlanCostEstimate visitProject(ProjectNode node, Void context) { - LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types)); + LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(node.getOutputVariables())); return costForStreaming(node, localCost); } @@ -172,8 +172,8 @@ public PlanCostEstimate visitAggregation(AggregationNode node, Void context) } PlanNodeStatsEstimate aggregationStats = getStats(node); PlanNodeStatsEstimate sourceStats = getStats(node.getSource()); - double cpuCost = sourceStats.getOutputSizeInBytes(node.getSource().getOutputSymbols(), types); - double memoryCost = aggregationStats.getOutputSizeInBytes(node.getOutputSymbols(), types); + double cpuCost = sourceStats.getOutputSizeInBytes(node.getSource().getOutputVariables()); + double memoryCost = aggregationStats.getOutputSizeInBytes(node.getOutputVariables()); LocalCostEstimate localCost = LocalCostEstimate.of(cpuCost, memoryCost, 0); return costForAccumulation(node, localCost); } @@ -205,7 +205,7 @@ private LocalCostEstimate calculateJoinCost(PlanNode join, PlanNode probe, PlanN private LocalCostEstimate calculateJoinOutputCost(PlanNode join) { PlanNodeStatsEstimate outputStats = getStats(join); - double joinOutputSize = outputStats.getOutputSizeInBytes(join.getOutputSymbols(), types); + double joinOutputSize = outputStats.getOutputSizeInBytes(join.getOutputVariables()); return LocalCostEstimate.ofCpu(joinOutputSize); } @@ -217,7 +217,7 @@ public PlanCostEstimate visitExchange(ExchangeNode node, Void context) private LocalCostEstimate calculateExchangeCost(ExchangeNode node) { - double inputSizeInBytes = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types); + double inputSizeInBytes = getStats(node).getOutputSizeInBytes(node.getOutputVariables()); switch (node.getScope()) { case LOCAL: switch (node.getType()) { @@ -291,7 +291,7 @@ public PlanCostEstimate visitLimit(LimitNode node, Void context) // so proper cost estimation is not that important. Second, since LimitNode can lead to incomplete evaluation // of the source, true cost estimation should be implemented as a "constraint" enforced on a sub-tree and // evaluated in context of actual source node type (and their sources). - LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types)); + LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(node.getOutputVariables())); return costForStreaming(node, localCost); } @@ -307,8 +307,8 @@ public PlanCostEstimate visitUnion(UnionNode node, Void context) @Override public PlanCostEstimate visitSort(SortNode node, Void context) { - double cpuCost = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types); - double memoryCost = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types); + double cpuCost = getStats(node).getOutputSizeInBytes(node.getOutputVariables()); + double memoryCost = getStats(node).getOutputSizeInBytes(node.getOutputVariables()); LocalCostEstimate localCost = LocalCostEstimate.of(cpuCost, memoryCost, 0); return costForAccumulation(node, localCost); } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java index 893d3ff69d76c..11815629e5dfb 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java @@ -114,7 +114,7 @@ public LocalCostEstimate visitGroupReference(GroupReference node, Void context) public LocalCostEstimate visitAggregation(AggregationNode node, Void context) { PlanNode source = node.getSource(); - double inputSizeInBytes = getStats(source).getOutputSizeInBytes(source.getOutputSymbols(), types); + double inputSizeInBytes = getStats(source).getOutputSizeInBytes(source.getOutputVariables()); LocalCostEstimate remoteRepartitionCost = calculateRemoteRepartitionCost(inputSizeInBytes); LocalCostEstimate localRepartitionCost = calculateLocalRepartitionCost(inputSizeInBytes); @@ -166,7 +166,7 @@ public LocalCostEstimate visitUnion(UnionNode node, Void context) // that is not aways true // but this estimate is better that returning UNKNOWN, as it sets // cumulative cost to unknown - double inputSizeInBytes = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types); + double inputSizeInBytes = getStats(node).getOutputSizeInBytes(node.getOutputVariables()); return calculateRemoteGatherCost(inputSizeInBytes); } @@ -229,8 +229,8 @@ private static LocalCostEstimate calculateJoinExchangeCost( boolean replicated, int estimatedSourceDistributedTaskCount) { - double probeSizeInBytes = stats.getStats(probe).getOutputSizeInBytes(probe.getOutputSymbols(), types); - double buildSizeInBytes = stats.getStats(build).getOutputSizeInBytes(build.getOutputSymbols(), types); + double probeSizeInBytes = stats.getStats(probe).getOutputSizeInBytes(probe.getOutputVariables()); + double buildSizeInBytes = stats.getStats(build).getOutputSizeInBytes(build.getOutputVariables()); if (replicated) { // assuming the probe side of a replicated join is always source distributed LocalCostEstimate replicateCost = calculateRemoteReplicateCost(buildSizeInBytes, estimatedSourceDistributedTaskCount); @@ -259,8 +259,8 @@ public static LocalCostEstimate calculateJoinInputCost( PlanNodeStatsEstimate probeStats = stats.getStats(probe); PlanNodeStatsEstimate buildStats = stats.getStats(build); - double buildSideSize = buildStats.getOutputSizeInBytes(build.getOutputSymbols(), types); - double probeSideSize = probeStats.getOutputSizeInBytes(probe.getOutputSymbols(), types); + double buildSideSize = buildStats.getOutputSizeInBytes(build.getOutputVariables()); + double probeSideSize = probeStats.getOutputSizeInBytes(probe.getOutputVariables()); double cpuCost = probeSideSize + buildSideSize * buildSizeMultiplier; diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java index 16652962115a7..ff2845c4c17b4 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java @@ -15,7 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -53,7 +53,7 @@ protected Optional doCalculate(ExchangeNode node, StatsPr PlanNode source = node.getSources().get(i); PlanNodeStatsEstimate sourceStats = statsProvider.getStats(source); - PlanNodeStatsEstimate sourceStatsWithMappedSymbols = mapToOutputSymbols(sourceStats, node.getInputs().get(i), node.getOutputSymbols()); + PlanNodeStatsEstimate sourceStatsWithMappedSymbols = mapToOutputVariables(sourceStats, node.getInputs().get(i), node.getOutputVariables()); if (estimate.isPresent()) { estimate = Optional.of(addStatsAndMaxDistinctValues(estimate.get(), sourceStatsWithMappedSymbols)); @@ -67,14 +67,14 @@ protected Optional doCalculate(ExchangeNode node, StatsPr return estimate; } - private PlanNodeStatsEstimate mapToOutputSymbols(PlanNodeStatsEstimate estimate, List inputs, List outputs) + private PlanNodeStatsEstimate mapToOutputVariables(PlanNodeStatsEstimate estimate, List inputs, List outputs) { checkArgument(inputs.size() == outputs.size(), "Input symbols count does not match output symbols count"); PlanNodeStatsEstimate.Builder mapped = PlanNodeStatsEstimate.builder() .setOutputRowCount(estimate.getOutputRowCount()); for (int i = 0; i < inputs.size(); i++) { - mapped.addSymbolStatistics(outputs.get(i), estimate.getSymbolStatistics(inputs.get(i))); + mapped.addVariableStatistics(outputs.get(i), estimate.getVariableStatistics(inputs.get(i))); } return mapped.build(); diff --git a/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java index a05596aaa0485..c7a9a36222e16 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java @@ -195,7 +195,7 @@ private class FilterExpressionStatsCalculatingVisitor @Override public PlanNodeStatsEstimate process(Node node, @Nullable Void context) { - return normalizer.normalize(super.process(node, context), types); + return normalizer.normalize(super.process(node, context)); } @Override @@ -289,7 +289,7 @@ protected PlanNodeStatsEstimate visitBooleanLiteral(BooleanLiteral node, Void co PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder(); result.setOutputRowCount(0.0); - input.getSymbolsWithKnownStatistics().forEach(symbol -> result.addSymbolStatistics(symbol, SymbolStatsEstimate.zero())); + input.getVariablesWithKnownStatistics().forEach(variable -> result.addVariableStatistics(variable, VariableStatsEstimate.zero())); return result.build(); } @@ -297,11 +297,11 @@ protected PlanNodeStatsEstimate visitBooleanLiteral(BooleanLiteral node, Void co protected PlanNodeStatsEstimate visitIsNotNullPredicate(IsNotNullPredicate node, Void context) { if (node.getValue() instanceof SymbolReference) { - Symbol symbol = Symbol.from(node.getValue()); - SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol); + VariableReferenceExpression variable = toVariable(Symbol.from(node.getValue())); + VariableStatsEstimate variableStats = input.getVariableStatistics(variable); PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input); - result.setOutputRowCount(input.getOutputRowCount() * (1 - symbolStats.getNullsFraction())); - result.addSymbolStatistics(symbol, symbolStats.mapNullsFraction(x -> 0.0)); + result.setOutputRowCount(input.getOutputRowCount() * (1 - variableStats.getNullsFraction())); + result.addVariableStatistics(variable, variableStats.mapNullsFraction(x -> 0.0)); return result.build(); } return PlanNodeStatsEstimate.unknown(); @@ -311,11 +311,11 @@ protected PlanNodeStatsEstimate visitIsNotNullPredicate(IsNotNullPredicate node, protected PlanNodeStatsEstimate visitIsNullPredicate(IsNullPredicate node, Void context) { if (node.getValue() instanceof SymbolReference) { - Symbol symbol = Symbol.from(node.getValue()); - SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol); + VariableReferenceExpression variable = toVariable(Symbol.from(node.getValue())); + VariableStatsEstimate variableStats = input.getVariableStatistics(variable); PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input); - result.setOutputRowCount(input.getOutputRowCount() * symbolStats.getNullsFraction()); - result.addSymbolStatistics(symbol, SymbolStatsEstimate.builder() + result.setOutputRowCount(input.getOutputRowCount() * variableStats.getNullsFraction()); + result.addVariableStatistics(variable, VariableStatsEstimate.builder() .setNullsFraction(1.0) .setLowValue(NaN) .setHighValue(NaN) @@ -339,7 +339,7 @@ protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Voi return PlanNodeStatsEstimate.unknown(); } - SymbolStatsEstimate valueStats = input.getSymbolStatistics(Symbol.from(node.getValue())); + VariableStatsEstimate valueStats = input.getVariableStatistics(toVariable(Symbol.from(node.getValue()))); Expression lowerBound = new ComparisonExpression(GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin()); Expression upperBound = new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax()); @@ -379,7 +379,7 @@ protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context) return PlanNodeStatsEstimate.unknown(); } - SymbolStatsEstimate valueStats = getExpressionStats(node.getValue()); + VariableStatsEstimate valueStats = getExpressionStats(node.getValue()); if (valueStats.isUnknown()) { return PlanNodeStatsEstimate.unknown(); } @@ -390,10 +390,10 @@ protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context) result.setOutputRowCount(min(inEstimate.getOutputRowCount(), notNullValuesBeforeIn)); if (node.getValue() instanceof SymbolReference) { - Symbol valueSymbol = Symbol.from(node.getValue()); - SymbolStatsEstimate newSymbolStats = inEstimate.getSymbolStatistics(valueSymbol) + VariableReferenceExpression valueVariable = toVariable(Symbol.from(node.getValue())); + VariableStatsEstimate newvariableStats = inEstimate.getVariableStatistics(valueVariable) .mapDistinctValuesCount(newDistinctValuesCount -> min(newDistinctValuesCount, valueStats.getDistinctValuesCount())); - result.addSymbolStatistics(valueSymbol, newSymbolStats); + result.addVariableStatistics(valueVariable, newvariableStats); } return result.build(); } @@ -421,25 +421,25 @@ protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression n return process(new IsNotNullPredicate(left)); } - SymbolStatsEstimate leftStats = getExpressionStats(left); - Optional leftSymbol = left instanceof SymbolReference ? Optional.of(Symbol.from(left)) : Optional.empty(); + VariableStatsEstimate leftStats = getExpressionStats(left); + Optional leftVariable = left instanceof SymbolReference ? Optional.of(toVariable(Symbol.from(left))) : Optional.empty(); if (right instanceof Literal) { Object literalValue = LiteralInterpreter.evaluate(metadata, session.toConnectorSession(), right); if (literalValue == null) { return visitBooleanLiteral(FALSE_LITERAL, null); } OptionalDouble literal = toStatsRepresentation(metadata, session, getType(left), literalValue); - return estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, literal, operator); + return estimateExpressionToLiteralComparison(input, leftStats, leftVariable, literal, operator); } - SymbolStatsEstimate rightStats = getExpressionStats(right); + VariableStatsEstimate rightStats = getExpressionStats(right); if (rightStats.isSingleValue()) { OptionalDouble value = isNaN(rightStats.getLowValue()) ? OptionalDouble.empty() : OptionalDouble.of(rightStats.getLowValue()); - return estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, value, operator); + return estimateExpressionToLiteralComparison(input, leftStats, leftVariable, value, operator); } - Optional rightSymbol = right instanceof SymbolReference ? Optional.of(Symbol.from(right)) : Optional.empty(); - return estimateExpressionToExpressionComparison(input, leftStats, leftSymbol, rightStats, rightSymbol, operator); + Optional rightVariable = right instanceof SymbolReference ? Optional.of(toVariable(Symbol.from(right))) : Optional.empty(); + return estimateExpressionToExpressionComparison(input, leftStats, leftVariable, rightStats, rightVariable, operator); } private Type getType(Expression expression) @@ -462,14 +462,19 @@ private Type getType(Expression expression) return expressionAnalyzer.analyze(expression, Scope.create()); } - private SymbolStatsEstimate getExpressionStats(Expression expression) + private VariableStatsEstimate getExpressionStats(Expression expression) { if (expression instanceof SymbolReference) { - Symbol symbol = Symbol.from(expression); - return requireNonNull(input.getSymbolStatistics(symbol), () -> format("No statistics for symbol %s", symbol)); + VariableReferenceExpression variable = toVariable(Symbol.from(expression)); + return requireNonNull(input.getVariableStatistics(variable), () -> format("No statistics for variable %s", variable)); } return scalarStatsCalculator.calculate(expression, input, session, types); } + + private VariableReferenceExpression toVariable(Symbol symbol) + { + return new VariableReferenceExpression(symbol.getName(), types.get(symbol)); + } } private class FilterRowExpressionStatsCalculatingVisitor @@ -518,7 +523,7 @@ public PlanNodeStatsEstimate visitConstant(ConstantExpression node, Void context } PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder(); result.setOutputRowCount(0.0); - input.getSymbolsWithKnownStatistics().forEach(symbol -> result.addSymbolStatistics(symbol, SymbolStatsEstimate.zero())); + input.getVariablesWithKnownStatistics().forEach(variable -> result.addVariableStatistics(variable, VariableStatsEstimate.zero())); return result.build(); } return PlanNodeStatsEstimate.unknown(); @@ -549,7 +554,7 @@ public PlanNodeStatsEstimate visitCall(CallExpression node, Void context) checkArgument(!(left instanceof ConstantExpression && right instanceof ConstantExpression), "Literal-to-literal not supported here, should be eliminated earlier"); if (!(left instanceof VariableReferenceExpression) && right instanceof VariableReferenceExpression) { - // normalize so that symbol is on the left + // normalize so that variable is on the left OperatorType flippedOperator = flip(operatorType); return process(call(flippedOperator.name(), metadata.getFunctionManager().resolveOperator(flippedOperator, fromTypes(right.getType(), left.getType())), BOOLEAN, right, left)); } @@ -564,25 +569,25 @@ public PlanNodeStatsEstimate visitCall(CallExpression node, Void context) return process(not(isNull(left))); } - SymbolStatsEstimate leftStats = getRowExpressionStats(left); - Optional leftSymbol = left instanceof VariableReferenceExpression ? Optional.of(new Symbol(((VariableReferenceExpression) left).getName())) : Optional.empty(); + VariableStatsEstimate leftStats = getRowExpressionStats(left); + Optional leftVariable = left instanceof VariableReferenceExpression ? Optional.of((VariableReferenceExpression) left) : Optional.empty(); if (right instanceof ConstantExpression) { Object rightValue = ((ConstantExpression) right).getValue(); if (rightValue == null) { return visitConstant(constantNull(BOOLEAN), null); } OptionalDouble literal = toStatsRepresentation(metadata, session, right.getType(), rightValue); - return estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, literal, getComparisonOperator(operatorType)); + return estimateExpressionToLiteralComparison(input, leftStats, leftVariable, literal, getComparisonOperator(operatorType)); } - SymbolStatsEstimate rightStats = getRowExpressionStats(right); + VariableStatsEstimate rightStats = getRowExpressionStats(right); if (rightStats.isSingleValue()) { OptionalDouble value = isNaN(rightStats.getLowValue()) ? OptionalDouble.empty() : OptionalDouble.of(rightStats.getLowValue()); - return estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, value, getComparisonOperator(operatorType)); + return estimateExpressionToLiteralComparison(input, leftStats, leftVariable, value, getComparisonOperator(operatorType)); } - Optional rightSymbol = right instanceof VariableReferenceExpression ? Optional.of(new Symbol(((VariableReferenceExpression) right).getName())) : Optional.empty(); - return estimateExpressionToExpressionComparison(input, leftStats, leftSymbol, rightStats, rightSymbol, getComparisonOperator(operatorType)); + Optional rightVariable = right instanceof VariableReferenceExpression ? Optional.of((VariableReferenceExpression) right) : Optional.empty(); + return estimateExpressionToExpressionComparison(input, leftStats, leftVariable, rightStats, rightVariable, getComparisonOperator(operatorType)); } // NOT case @@ -592,11 +597,11 @@ public PlanNodeStatsEstimate visitCall(CallExpression node, Void context) // IS NOT NULL case RowExpression innerArugment = ((SpecialFormExpression) arguemnt).getArguments().get(0); if (innerArugment instanceof VariableReferenceExpression) { - Symbol symbol = new Symbol(((VariableReferenceExpression) innerArugment).getName()); - SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol); + VariableReferenceExpression variable = (VariableReferenceExpression) innerArugment; + VariableStatsEstimate variableStats = input.getVariableStatistics(variable); PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input); - result.setOutputRowCount(input.getOutputRowCount() * (1 - symbolStats.getNullsFraction())); - result.addSymbolStatistics(symbol, symbolStats.mapNullsFraction(x -> 0.0)); + result.setOutputRowCount(input.getOutputRowCount() * (1 - variableStats.getNullsFraction())); + result.addVariableStatistics(variable, variableStats.mapNullsFraction(x -> 0.0)); return result.build(); } return PlanNodeStatsEstimate.unknown(); @@ -619,7 +624,7 @@ public PlanNodeStatsEstimate visitCall(CallExpression node, Void context) return PlanNodeStatsEstimate.unknown(); } - SymbolStatsEstimate valueStats = input.getSymbolStatistics(new Symbol(((VariableReferenceExpression) value).getName())); + VariableStatsEstimate valueStats = input.getVariableStatistics((VariableReferenceExpression) value); RowExpression lowerBound = call( OperatorType.GREATER_THAN_OR_EQUAL.name(), metadata.getFunctionManager().resolveOperator(OperatorType.GREATER_THAN_OR_EQUAL, fromTypes(value.getType(), min.getType())), @@ -661,7 +666,7 @@ private FilterRowExpressionStatsCalculatingVisitor newEstimate(PlanNodeStatsEsti private PlanNodeStatsEstimate process(RowExpression rowExpression) { - return normalizer.normalize(rowExpression.accept(this, null), types); + return normalizer.normalize(rowExpression.accept(this, null)); } private PlanNodeStatsEstimate estimateLogicalAnd(RowExpression left, RowExpression right) @@ -736,7 +741,7 @@ private PlanNodeStatsEstimate estimateIn(RowExpression value, List min(newDistinctValuesCount, valueStats.getDistinctValuesCount())); - result.addSymbolStatistics(valueSymbol, newSymbolStats); + result.addVariableStatistics(valueVariable, newVariableStats); } return result.build(); } @@ -758,11 +763,11 @@ private PlanNodeStatsEstimate estimateIn(RowExpression value, List format("No statistics for symbol %s", symbol)); + VariableReferenceExpression variable = (VariableReferenceExpression) expression; + return requireNonNull(input.getVariableStatistics(variable), () -> format("No statistics for variable %s", variable)); } return scalarStatsCalculator.calculate(expression, input, session); } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/JoinStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/JoinStatsRule.java index 747baaa1b5728..3decf7bec1378 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/JoinStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/JoinStatsRule.java @@ -17,12 +17,13 @@ import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.relation.LogicalRowExpressions; import com.facebook.presto.spi.relation.RowExpression; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.JoinNode.EquiJoinClause; import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.util.MoreMath; import com.google.common.annotations.VisibleForTesting; @@ -33,7 +34,7 @@ import java.util.Queue; import static com.facebook.presto.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT; -import static com.facebook.presto.cost.SymbolStatsEstimate.buildFrom; +import static com.facebook.presto.cost.VariableStatsEstimate.buildFrom; import static com.facebook.presto.sql.ExpressionUtils.extractConjuncts; import static com.facebook.presto.sql.planner.plan.Patterns.join; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; @@ -182,7 +183,7 @@ private PlanNodeStatsEstimate computeInnerJoinStats(JoinNode node, PlanNodeStats } if (filteredEquiJoinEstimate.isOutputRowCountUnknown()) { - return normalizer.normalize(equiJoinEstimate.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT), types); + return normalizer.normalize(equiJoinEstimate.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT)); } return filteredEquiJoinEstimate; @@ -221,7 +222,7 @@ private PlanNodeStatsEstimate filterByEquiJoinClauses( Session session, TypeProvider types) { - ComparisonExpression drivingPredicate = new ComparisonExpression(EQUAL, drivingClause.getLeft().toSymbolReference(), drivingClause.getRight().toSymbolReference()); + ComparisonExpression drivingPredicate = new ComparisonExpression(EQUAL, new SymbolReference(drivingClause.getLeft().getName()), new SymbolReference(drivingClause.getRight().getName())); PlanNodeStatsEstimate filteredStats = filterStatsCalculator.filterStats(stats, drivingPredicate, session, types); for (EquiJoinClause clause : remainingClauses) { filteredStats = filterByAuxiliaryClause(filteredStats, clause, types); @@ -234,8 +235,8 @@ private PlanNodeStatsEstimate filterByAuxiliaryClause(PlanNodeStatsEstimate stat // we just clear null fraction and adjust ranges here // selectivity is mostly handled by driving clause. We just scale heuristically by UNKNOWN_FILTER_COEFFICIENT here. - SymbolStatsEstimate leftStats = stats.getSymbolStatistics(clause.getLeft()); - SymbolStatsEstimate rightStats = stats.getSymbolStatistics(clause.getRight()); + VariableStatsEstimate leftStats = stats.getVariableStatistics(clause.getLeft()); + VariableStatsEstimate rightStats = stats.getVariableStatistics(clause.getRight()); StatisticRange leftRange = StatisticRange.from(leftStats); StatisticRange rightRange = StatisticRange.from(rightStats); @@ -246,13 +247,13 @@ private PlanNodeStatsEstimate filterByAuxiliaryClause(PlanNodeStatsEstimate stat double rightNdvInRange = rightFilterValue * rightRange.getDistinctValuesCount(); double retainedNdv = MoreMath.min(leftNdvInRange, rightNdvInRange); - SymbolStatsEstimate newLeftStats = buildFrom(leftStats) + VariableStatsEstimate newLeftStats = buildFrom(leftStats) .setNullsFraction(0) .setStatisticsRange(intersect) .setDistinctValuesCount(retainedNdv) .build(); - SymbolStatsEstimate newRightStats = buildFrom(rightStats) + VariableStatsEstimate newRightStats = buildFrom(rightStats) .setNullsFraction(0) .setStatisticsRange(intersect) .setDistinctValuesCount(retainedNdv) @@ -260,9 +261,9 @@ private PlanNodeStatsEstimate filterByAuxiliaryClause(PlanNodeStatsEstimate stat PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(stats) .setOutputRowCount(stats.getOutputRowCount() * UNKNOWN_FILTER_COEFFICIENT) - .addSymbolStatistics(clause.getLeft(), newLeftStats) - .addSymbolStatistics(clause.getRight(), newRightStats); - return normalizer.normalize(result.build(), types); + .addVariableStatistics(clause.getLeft(), newLeftStats) + .addVariableStatistics(clause.getRight(), newRightStats); + return normalizer.normalize(result.build()); } private static double firstNonNaN(double... values) @@ -297,7 +298,7 @@ PlanNodeStatsEstimate calculateJoinComplementStats( return PlanNodeStatsEstimate.unknown(); } - return normalizer.normalize(leftStats.mapOutputRowCount(rowCount -> 0.0), types); + return normalizer.normalize(leftStats.mapOutputRowCount(rowCount -> 0.0)); } // TODO: add support for non-equality conditions (e.g: <=, !=, >) @@ -321,7 +322,7 @@ PlanNodeStatsEstimate calculateJoinComplementStats( .map(drivingClause -> calculateJoinComplementStats(leftStats, rightStats, drivingClause, criteria.size() - 1 + numberOfFilterClauses)) .filter(estimate -> !estimate.isOutputRowCountUnknown()) .max(comparingDouble(PlanNodeStatsEstimate::getOutputRowCount)) - .map(estimate -> normalizer.normalize(estimate, types)) + .map(estimate -> normalizer.normalize(estimate)) .orElse(PlanNodeStatsEstimate.unknown()); } @@ -333,8 +334,8 @@ private PlanNodeStatsEstimate calculateJoinComplementStats( { PlanNodeStatsEstimate result = leftStats; - SymbolStatsEstimate leftColumnStats = leftStats.getSymbolStatistics(drivingClause.getLeft()); - SymbolStatsEstimate rightColumnStats = rightStats.getSymbolStatistics(drivingClause.getRight()); + VariableStatsEstimate leftColumnStats = leftStats.getVariableStatistics(drivingClause.getLeft()); + VariableStatsEstimate rightColumnStats = rightStats.getVariableStatistics(drivingClause.getRight()); // TODO: use range methods when they have defined (and consistent) semantics double leftNDV = leftColumnStats.getDistinctValuesCount(); @@ -345,8 +346,8 @@ private PlanNodeStatsEstimate calculateJoinComplementStats( double nonMatchingLeftValuesFraction = leftColumnStats.getValuesFraction() * (leftNDV - matchingRightNDV) / leftNDV; double scaleFactor = nonMatchingLeftValuesFraction + leftColumnStats.getNullsFraction(); double newLeftNullsFraction = leftColumnStats.getNullsFraction() / scaleFactor; - result = result.mapSymbolColumnStatistics(drivingClause.getLeft(), columnStats -> - SymbolStatsEstimate.buildFrom(columnStats) + result = result.mapVariableColumnStatistics(drivingClause.getLeft(), columnStats -> + VariableStatsEstimate.buildFrom(columnStats) .setLowValue(leftColumnStats.getLowValue()) .setHighValue(leftColumnStats.getHighValue()) .setNullsFraction(newLeftNullsFraction) @@ -356,8 +357,8 @@ private PlanNodeStatsEstimate calculateJoinComplementStats( } else if (leftNDV <= matchingRightNDV) { // Assume all non-null left rows are matched. Therefore only null left rows are unmatched. - result = result.mapSymbolColumnStatistics(drivingClause.getLeft(), columnStats -> - SymbolStatsEstimate.buildFrom(columnStats) + result = result.mapVariableColumnStatistics(drivingClause.getLeft(), columnStats -> + VariableStatsEstimate.buildFrom(columnStats) .setLowValue(NaN) .setHighValue(NaN) .setNullsFraction(1.0) @@ -393,14 +394,14 @@ PlanNodeStatsEstimate addJoinComplementStats( PlanNodeStatsEstimate.Builder outputStats = PlanNodeStatsEstimate.buildFrom(innerJoinStats); outputStats.setOutputRowCount(outputRowCount); - for (Symbol symbol : joinComplementStats.getSymbolsWithKnownStatistics()) { - SymbolStatsEstimate leftSymbolStats = sourceStats.getSymbolStatistics(symbol); - SymbolStatsEstimate innerJoinSymbolStats = innerJoinStats.getSymbolStatistics(symbol); - SymbolStatsEstimate joinComplementSymbolStats = joinComplementStats.getSymbolStatistics(symbol); + for (VariableReferenceExpression variable : joinComplementStats.getVariablesWithKnownStatistics()) { + VariableStatsEstimate leftSymbolStats = sourceStats.getVariableStatistics(variable); + VariableStatsEstimate innerJoinSymbolStats = innerJoinStats.getVariableStatistics(variable); + VariableStatsEstimate joinComplementSymbolStats = joinComplementStats.getVariableStatistics(variable); // weighted average double newNullsFraction = (innerJoinSymbolStats.getNullsFraction() * innerJoinRowCount + joinComplementSymbolStats.getNullsFraction() * joinComplementRowCount) / outputRowCount; - outputStats.addSymbolStatistics(symbol, SymbolStatsEstimate.buildFrom(innerJoinSymbolStats) + outputStats.addVariableStatistics(variable, VariableStatsEstimate.buildFrom(innerJoinSymbolStats) // in outer join low value, high value and NDVs of outer side columns are preserved .setLowValue(leftSymbolStats.getLowValue()) .setHighValue(leftSymbolStats.getHighValue()) @@ -410,10 +411,10 @@ PlanNodeStatsEstimate addJoinComplementStats( } // add nulls to columns that don't exist in right stats - for (Symbol symbol : difference(innerJoinStats.getSymbolsWithKnownStatistics(), joinComplementStats.getSymbolsWithKnownStatistics())) { - SymbolStatsEstimate innerJoinSymbolStats = innerJoinStats.getSymbolStatistics(symbol); + for (VariableReferenceExpression variable : difference(innerJoinStats.getVariablesWithKnownStatistics(), joinComplementStats.getVariablesWithKnownStatistics())) { + VariableStatsEstimate innerJoinSymbolStats = innerJoinStats.getVariableStatistics(variable); double newNullsFraction = (innerJoinSymbolStats.getNullsFraction() * innerJoinRowCount + joinComplementRowCount) / outputRowCount; - outputStats.addSymbolStatistics(symbol, innerJoinSymbolStats.mapNullsFraction(nullsFraction -> newNullsFraction)); + outputStats.addVariableStatistics(variable, innerJoinSymbolStats.mapNullsFraction(nullsFraction -> newNullsFraction)); } return outputStats.build(); @@ -424,10 +425,10 @@ private PlanNodeStatsEstimate crossJoinStats(JoinNode node, PlanNodeStatsEstimat PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder() .setOutputRowCount(leftStats.getOutputRowCount() * rightStats.getOutputRowCount()); - node.getLeft().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, leftStats.getSymbolStatistics(symbol))); - node.getRight().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, rightStats.getSymbolStatistics(symbol))); + node.getLeft().getOutputVariables().forEach(variable -> builder.addVariableStatistics(variable, leftStats.getVariableStatistics(variable))); + node.getRight().getOutputVariables().forEach(variable -> builder.addVariableStatistics(variable, rightStats.getVariableStatistics(variable))); - return normalizer.normalize(builder.build(), types); + return normalizer.normalize(builder.build()); } private List flippedCriteria(JoinNode node) diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java index 449c34a304418..91df9ae9582e5 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java @@ -13,11 +13,10 @@ */ package com.facebook.presto.cost; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.FixedWidthType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.VariableWidthType; -import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.TypeProvider; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableMap; @@ -43,7 +42,7 @@ public class PlanNodeStatsEstimate private static final PlanNodeStatsEstimate UNKNOWN = new PlanNodeStatsEstimate(NaN, ImmutableMap.of()); private final double outputRowCount; - private final PMap symbolStatistics; + private final PMap variableStatistics; public static PlanNodeStatsEstimate unknown() { @@ -53,16 +52,16 @@ public static PlanNodeStatsEstimate unknown() @JsonCreator public PlanNodeStatsEstimate( @JsonProperty("outputRowCount") double outputRowCount, - @JsonProperty("symbolStatistics") Map symbolStatistics) + @JsonProperty("variableStatistics") Map variableStatistics) { - this(outputRowCount, HashTreePMap.from(requireNonNull(symbolStatistics, "symbolStatistics is null"))); + this(outputRowCount, HashTreePMap.from(requireNonNull(variableStatistics, "variableStatistics is null"))); } - private PlanNodeStatsEstimate(double outputRowCount, PMap symbolStatistics) + private PlanNodeStatsEstimate(double outputRowCount, PMap variableStatistics) { checkArgument(isNaN(outputRowCount) || outputRowCount >= 0, "outputRowCount cannot be negative"); this.outputRowCount = outputRowCount; - this.symbolStatistics = symbolStatistics; + this.variableStatistics = variableStatistics; } /** @@ -79,21 +78,21 @@ public double getOutputRowCount() * Returns estimated data size. * Unknown value is represented by {@link Double#NaN} */ - public double getOutputSizeInBytes(Collection outputSymbols, TypeProvider types) + public double getOutputSizeInBytes(Collection outputVariables) { - requireNonNull(outputSymbols, "outputSymbols is null"); + requireNonNull(outputVariables, "outputSymbols is null"); - return outputSymbols.stream() - .mapToDouble(symbol -> getOutputSizeForSymbol(getSymbolStatistics(symbol), types.get(symbol))) + return outputVariables.stream() + .mapToDouble(variable -> getOutputSizeForVariable(getVariableStatistics(variable), variable.getType())) .sum(); } - private double getOutputSizeForSymbol(SymbolStatsEstimate symbolStatistics, Type type) + private double getOutputSizeForVariable(VariableStatsEstimate variableStatistics, Type type) { checkArgument(type != null, "type is null"); - double averageRowSize = symbolStatistics.getAverageRowSize(); - double nullsFraction = firstNonNaN(symbolStatistics.getNullsFraction(), 0d); + double averageRowSize = variableStatistics.getAverageRowSize(); + double nullsFraction = firstNonNaN(variableStatistics.getNullsFraction(), 0d); double numberOfNonNullRows = outputRowCount * (1.0 - nullsFraction); if (isNaN(averageRowSize)) { @@ -123,27 +122,27 @@ public PlanNodeStatsEstimate mapOutputRowCount(Function mappingF return buildFrom(this).setOutputRowCount(mappingFunction.apply(outputRowCount)).build(); } - public PlanNodeStatsEstimate mapSymbolColumnStatistics(Symbol symbol, Function mappingFunction) + public PlanNodeStatsEstimate mapVariableColumnStatistics(VariableReferenceExpression variable, Function mappingFunction) { return buildFrom(this) - .addSymbolStatistics(symbol, mappingFunction.apply(getSymbolStatistics(symbol))) + .addVariableStatistics(variable, mappingFunction.apply(getVariableStatistics(variable))) .build(); } - public SymbolStatsEstimate getSymbolStatistics(Symbol symbol) + public VariableStatsEstimate getVariableStatistics(VariableReferenceExpression variable) { - return symbolStatistics.getOrDefault(symbol, SymbolStatsEstimate.unknown()); + return variableStatistics.getOrDefault(variable, VariableStatsEstimate.unknown()); } @JsonProperty - public Map getSymbolStatistics() + public Map getVariableStatistics() { - return symbolStatistics; + return variableStatistics; } - public Set getSymbolsWithKnownStatistics() + public Set getVariablesWithKnownStatistics() { - return symbolStatistics.keySet(); + return variableStatistics.keySet(); } public boolean isOutputRowCountUnknown() @@ -156,7 +155,7 @@ public String toString() { return toStringHelper(this) .add("outputRowCount", outputRowCount) - .add("symbolStatistics", symbolStatistics) + .add("variableStatistics", variableStatistics) .toString(); } @@ -171,13 +170,13 @@ public boolean equals(Object o) } PlanNodeStatsEstimate that = (PlanNodeStatsEstimate) o; return Double.compare(outputRowCount, that.outputRowCount) == 0 && - Objects.equals(symbolStatistics, that.symbolStatistics); + Objects.equals(variableStatistics, that.variableStatistics); } @Override public int hashCode() { - return Objects.hash(outputRowCount, symbolStatistics); + return Objects.hash(outputRowCount, variableStatistics); } public static Builder builder() @@ -187,23 +186,23 @@ public static Builder builder() public static Builder buildFrom(PlanNodeStatsEstimate other) { - return new Builder(other.getOutputRowCount(), other.symbolStatistics); + return new Builder(other.getOutputRowCount(), other.variableStatistics); } public static final class Builder { private double outputRowCount; - private PMap symbolStatistics; + private PMap variableStatistics; public Builder() { this(NaN, HashTreePMap.empty()); } - private Builder(double outputRowCount, PMap symbolStatistics) + private Builder(double outputRowCount, PMap variableStatistics) { this.outputRowCount = outputRowCount; - this.symbolStatistics = symbolStatistics; + this.variableStatistics = variableStatistics; } public Builder setOutputRowCount(double outputRowCount) @@ -212,27 +211,27 @@ public Builder setOutputRowCount(double outputRowCount) return this; } - public Builder addSymbolStatistics(Symbol symbol, SymbolStatsEstimate statistics) + public Builder addVariableStatistics(VariableReferenceExpression variable, VariableStatsEstimate statistics) { - symbolStatistics = symbolStatistics.plus(symbol, statistics); + variableStatistics = variableStatistics.plus(variable, statistics); return this; } - public Builder addSymbolStatistics(Map symbolStatistics) + public Builder addVariableStatistics(Map variableStatistics) { - this.symbolStatistics = this.symbolStatistics.plusAll(symbolStatistics); + this.variableStatistics = this.variableStatistics.plusAll(variableStatistics); return this; } - public Builder removeSymbolStatistics(Symbol symbol) + public Builder removeVariableStatistics(VariableReferenceExpression variable) { - symbolStatistics = symbolStatistics.minus(symbol); + variableStatistics = variableStatistics.minus(variable); return this; } public PlanNodeStatsEstimate build() { - return new PlanNodeStatsEstimate(outputRowCount, symbolStatistics); + return new PlanNodeStatsEstimate(outputRowCount, variableStatistics); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimateMath.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimateMath.java index e2804fcb40882..4af6fef590d31 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimateMath.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimateMath.java @@ -47,11 +47,11 @@ public static PlanNodeStatsEstimate subtractSubsetStats(PlanNodeStatsEstimate su PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder(); result.setOutputRowCount(outputRowCount); - superset.getSymbolsWithKnownStatistics().forEach(symbol -> { - SymbolStatsEstimate supersetSymbolStats = superset.getSymbolStatistics(symbol); - SymbolStatsEstimate subsetSymbolStats = subset.getSymbolStatistics(symbol); + superset.getVariablesWithKnownStatistics().forEach(symbol -> { + VariableStatsEstimate supersetSymbolStats = superset.getVariableStatistics(symbol); + VariableStatsEstimate subsetSymbolStats = subset.getVariableStatistics(symbol); - SymbolStatsEstimate.Builder newSymbolStats = SymbolStatsEstimate.builder(); + VariableStatsEstimate.Builder newSymbolStats = VariableStatsEstimate.builder(); // for simplicity keep the average row size the same as in the input // in most cases the average row size doesn't change after applying filters @@ -94,7 +94,7 @@ else if (subsetDistinctValues == 0) { newSymbolStats.setLowValue(supersetSymbolStats.getLowValue()); newSymbolStats.setHighValue(supersetSymbolStats.getHighValue()); - result.addSymbolStatistics(symbol, newSymbolStats.build()); + result.addVariableStatistics(symbol, newSymbolStats.build()); }); return result.build(); @@ -110,11 +110,11 @@ public static PlanNodeStatsEstimate capStats(PlanNodeStatsEstimate stats, PlanNo double cappedRowCount = min(stats.getOutputRowCount(), cap.getOutputRowCount()); result.setOutputRowCount(cappedRowCount); - stats.getSymbolsWithKnownStatistics().forEach(symbol -> { - SymbolStatsEstimate symbolStats = stats.getSymbolStatistics(symbol); - SymbolStatsEstimate capSymbolStats = cap.getSymbolStatistics(symbol); + stats.getVariablesWithKnownStatistics().forEach(symbol -> { + VariableStatsEstimate symbolStats = stats.getVariableStatistics(symbol); + VariableStatsEstimate capSymbolStats = cap.getVariableStatistics(symbol); - SymbolStatsEstimate.Builder newSymbolStats = SymbolStatsEstimate.builder(); + VariableStatsEstimate.Builder newSymbolStats = VariableStatsEstimate.builder(); // for simplicity keep the average row size the same as in the input // in most cases the average row size doesn't change after applying filters @@ -130,7 +130,7 @@ public static PlanNodeStatsEstimate capStats(PlanNodeStatsEstimate stats, PlanNo double cappedNullsFraction = cappedRowCount == 0 ? 1 : cappedNumberOfNulls / cappedRowCount; newSymbolStats.setNullsFraction(cappedNullsFraction); - result.addSymbolStatistics(symbol, newSymbolStats.build()); + result.addVariableStatistics(symbol, newSymbolStats.build()); }); return result.build(); @@ -140,7 +140,7 @@ private static PlanNodeStatsEstimate createZeroStats(PlanNodeStatsEstimate stats { PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder(); result.setOutputRowCount(0); - stats.getSymbolsWithKnownStatistics().forEach(symbol -> result.addSymbolStatistics(symbol, SymbolStatsEstimate.zero())); + stats.getVariablesWithKnownStatistics().forEach(symbol -> result.addVariableStatistics(symbol, VariableStatsEstimate.zero())); return result.build(); } @@ -174,26 +174,26 @@ private static PlanNodeStatsEstimate addStats(PlanNodeStatsEstimate left, PlanNo PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder(); double newRowCount = left.getOutputRowCount() + right.getOutputRowCount(); - concat(left.getSymbolsWithKnownStatistics().stream(), right.getSymbolsWithKnownStatistics().stream()) + concat(left.getVariablesWithKnownStatistics().stream(), right.getVariablesWithKnownStatistics().stream()) .distinct() .forEach(symbol -> { - SymbolStatsEstimate symbolStats = SymbolStatsEstimate.zero(); + VariableStatsEstimate symbolStats = VariableStatsEstimate.zero(); if (newRowCount > 0) { symbolStats = addColumnStats( - left.getSymbolStatistics(symbol), + left.getVariableStatistics(symbol), left.getOutputRowCount(), - right.getSymbolStatistics(symbol), + right.getVariableStatistics(symbol), right.getOutputRowCount(), newRowCount, strategy); } - statsBuilder.addSymbolStatistics(symbol, symbolStats); + statsBuilder.addVariableStatistics(symbol, symbolStats); }); return statsBuilder.setOutputRowCount(newRowCount).build(); } - private static SymbolStatsEstimate addColumnStats(SymbolStatsEstimate leftStats, double leftRows, SymbolStatsEstimate rightStats, double rightRows, double newRowCount, RangeAdditionStrategy strategy) + private static VariableStatsEstimate addColumnStats(VariableStatsEstimate leftStats, double leftRows, VariableStatsEstimate rightStats, double rightRows, double newRowCount, RangeAdditionStrategy strategy) { checkArgument(newRowCount > 0, "newRowCount must be greater than zero"); @@ -211,7 +211,7 @@ private static SymbolStatsEstimate addColumnStats(SymbolStatsEstimate leftStats, // FIXME, weights to average. left and right should be equal in most cases anyway double newAverageRowSize = newNonNullsRowCount == 0 ? 0 : ((totalSizeLeft + totalSizeRight) / newNonNullsRowCount); - return SymbolStatsEstimate.builder() + return VariableStatsEstimate.builder() .setStatisticsRange(sum) .setAverageRowSize(newAverageRowSize) .setNullsFraction(newNullsFraction) diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java index cb6ad1e739ba4..afc3824374b52 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java @@ -15,7 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -53,8 +53,8 @@ protected Optional doCalculate(ProjectNode node, StatsPro PlanNodeStatsEstimate.Builder calculatedStats = PlanNodeStatsEstimate.builder() .setOutputRowCount(sourceStats.getOutputRowCount()); - for (Map.Entry entry : node.getAssignments().entrySet()) { - calculatedStats.addSymbolStatistics(entry.getKey(), scalarStatsCalculator.calculate(entry.getValue(), sourceStats, session, types)); + for (Map.Entry entry : node.getAssignments().entrySet()) { + calculatedStats.addVariableStatistics(entry.getKey(), scalarStatsCalculator.calculate(entry.getValue(), sourceStats, session, types)); } return Optional.of(calculatedStats.build()); } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/RowNumberStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/RowNumberStatsRule.java index 0c4e10d135dfb..0cccf2a31e36b 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/RowNumberStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/RowNumberStatsRule.java @@ -15,7 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.Patterns; @@ -53,11 +53,11 @@ public Optional doCalculate(RowNumberNode node, StatsProv double sourceRowsCount = sourceStats.getOutputRowCount(); double partitionCount = 1; - for (Symbol groupBySymbol : node.getPartitionBy()) { - SymbolStatsEstimate symbolStatistics = sourceStats.getSymbolStatistics(groupBySymbol); - int nullRow = (symbolStatistics.getNullsFraction() == 0.0) ? 0 : 1; + for (VariableReferenceExpression groupByVariable : node.getPartitionBy()) { + VariableStatsEstimate variableStatistics = sourceStats.getVariableStatistics(groupByVariable); + int nullRow = (variableStatistics.getNullsFraction() == 0.0) ? 0 : 1; // assuming no correlation between grouping keys - partitionCount *= symbolStatistics.getDistinctValuesCount() + nullRow; + partitionCount *= variableStatistics.getDistinctValuesCount() + nullRow; } partitionCount = min(sourceRowsCount, partitionCount); @@ -78,7 +78,7 @@ public Optional doCalculate(RowNumberNode node, StatsProv return Optional.of(PlanNodeStatsEstimate.buildFrom(sourceStats) .setOutputRowCount(outputRowsCount) - .addSymbolStatistics(node.getRowNumberSymbol(), SymbolStatsEstimate.builder() + .addVariableStatistics(node.getRowNumberVariable(), VariableStatsEstimate.builder() // Note: if we assume no skew, we could also estimate highValue // (as rowsPerPartition), but underestimation of highValue may have // more severe consequences than underestimation of distinctValuesCount diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java index 07ef1fe62b322..bd8a071c87772 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java @@ -83,18 +83,18 @@ public ScalarStatsCalculator(Metadata metadata) } @Deprecated - public SymbolStatsEstimate calculate(Expression scalarExpression, PlanNodeStatsEstimate inputStatistics, Session session, TypeProvider types) + public VariableStatsEstimate calculate(Expression scalarExpression, PlanNodeStatsEstimate inputStatistics, Session session, TypeProvider types) { return new ExpressionStatsVisitor(inputStatistics, session, types).process(scalarExpression); } - public SymbolStatsEstimate calculate(RowExpression scalarExpression, PlanNodeStatsEstimate inputStatistics, Session session) + public VariableStatsEstimate calculate(RowExpression scalarExpression, PlanNodeStatsEstimate inputStatistics, Session session) { return scalarExpression.accept(new RowExpressionStatsVisitor(inputStatistics, session), null); } private class RowExpressionStatsVisitor - implements RowExpressionVisitor + implements RowExpressionVisitor { private final PlanNodeStatsEstimate input; private final Session session; @@ -107,7 +107,7 @@ public RowExpressionStatsVisitor(PlanNodeStatsEstimate input, Session session) } @Override - public SymbolStatsEstimate visitCall(CallExpression call, Void context) + public VariableStatsEstimate visitCall(CallExpression call, Void context) { if (resolution.isCastFunction(call.getFunctionHandle())) { return computeCastStatistics(call, context); @@ -130,31 +130,31 @@ public SymbolStatsEstimate visitCall(CallExpression call, Void context) if (value instanceof RowExpression) { // value is not a constant - return SymbolStatsEstimate.unknown(); + return VariableStatsEstimate.unknown(); } // value is a constant - return SymbolStatsEstimate.builder() + return VariableStatsEstimate.builder() .setNullsFraction(0) .setDistinctValuesCount(1) .build(); } @Override - public SymbolStatsEstimate visitInputReference(InputReferenceExpression reference, Void context) + public VariableStatsEstimate visitInputReference(InputReferenceExpression reference, Void context) { throw new UnsupportedOperationException("symbol stats estimation should not reach channel mapping"); } @Override - public SymbolStatsEstimate visitConstant(ConstantExpression literal, Void context) + public VariableStatsEstimate visitConstant(ConstantExpression literal, Void context) { if (literal.getValue() == null) { return nullStatsEstimate(); } OptionalDouble doubleValue = toStatsRepresentation(metadata, session, literal.getType(), literal.getValue()); - SymbolStatsEstimate.Builder estimate = SymbolStatsEstimate.builder() + VariableStatsEstimate.Builder estimate = VariableStatsEstimate.builder() .setNullsFraction(0) .setDistinctValuesCount(1); @@ -166,24 +166,24 @@ public SymbolStatsEstimate visitConstant(ConstantExpression literal, Void contex } @Override - public SymbolStatsEstimate visitLambda(LambdaDefinitionExpression lambda, Void context) + public VariableStatsEstimate visitLambda(LambdaDefinitionExpression lambda, Void context) { - return SymbolStatsEstimate.unknown(); + return VariableStatsEstimate.unknown(); } @Override - public SymbolStatsEstimate visitVariableReference(VariableReferenceExpression reference, Void context) + public VariableStatsEstimate visitVariableReference(VariableReferenceExpression reference, Void context) { - return input.getSymbolStatistics(new Symbol(reference.getName())); + return input.getVariableStatistics(reference); } @Override - public SymbolStatsEstimate visitSpecialForm(SpecialFormExpression specialForm, Void context) + public VariableStatsEstimate visitSpecialForm(SpecialFormExpression specialForm, Void context) { if (specialForm.getForm().equals(COALESCE)) { - SymbolStatsEstimate result = null; + VariableStatsEstimate result = null; for (RowExpression operand : specialForm.getArguments()) { - SymbolStatsEstimate operandEstimates = operand.accept(this, context); + VariableStatsEstimate operandEstimates = operand.accept(this, context); if (result != null) { result = estimateCoalesce(input, result, operandEstimates); } @@ -193,13 +193,13 @@ public SymbolStatsEstimate visitSpecialForm(SpecialFormExpression specialForm, V } return requireNonNull(result, "result is null"); } - return SymbolStatsEstimate.unknown(); + return VariableStatsEstimate.unknown(); } - private SymbolStatsEstimate computeCastStatistics(CallExpression call, Void context) + private VariableStatsEstimate computeCastStatistics(CallExpression call, Void context) { requireNonNull(call, "call is null"); - SymbolStatsEstimate sourceStats = call.getArguments().get(0).accept(this, context); + VariableStatsEstimate sourceStats = call.getArguments().get(0).accept(this, context); // todo - make this general postprocessing rule. double distinctValuesCount = sourceStats.getDistinctValuesCount(); @@ -222,7 +222,7 @@ private SymbolStatsEstimate computeCastStatistics(CallExpression call, Void cont } } - return SymbolStatsEstimate.builder() + return VariableStatsEstimate.builder() .setNullsFraction(sourceStats.getNullsFraction()) .setLowValue(lowValue) .setHighValue(highValue) @@ -230,12 +230,12 @@ private SymbolStatsEstimate computeCastStatistics(CallExpression call, Void cont .build(); } - private SymbolStatsEstimate computeNegationStatistics(CallExpression call, Void context) + private VariableStatsEstimate computeNegationStatistics(CallExpression call, Void context) { requireNonNull(call, "call is null"); - SymbolStatsEstimate stats = call.getArguments().get(0).accept(this, context); + VariableStatsEstimate stats = call.getArguments().get(0).accept(this, context); if (resolution.isNegateFunction(call.getFunctionHandle())) { - return SymbolStatsEstimate.buildFrom(stats) + return VariableStatsEstimate.buildFrom(stats) .setLowValue(-stats.getHighValue()) .setHighValue(-stats.getLowValue()) .build(); @@ -243,13 +243,13 @@ private SymbolStatsEstimate computeNegationStatistics(CallExpression call, Void throw new IllegalStateException(format("Unexpected sign: %s(%s)" + call.getDisplayName(), call.getFunctionHandle())); } - private SymbolStatsEstimate computeArithmeticBinaryStatistics(CallExpression call, Void context) + private VariableStatsEstimate computeArithmeticBinaryStatistics(CallExpression call, Void context) { requireNonNull(call, "call is null"); - SymbolStatsEstimate left = call.getArguments().get(0).accept(this, context); - SymbolStatsEstimate right = call.getArguments().get(1).accept(this, context); + VariableStatsEstimate left = call.getArguments().get(0).accept(this, context); + VariableStatsEstimate right = call.getArguments().get(1).accept(this, context); - SymbolStatsEstimate.Builder result = SymbolStatsEstimate.builder() + VariableStatsEstimate.Builder result = VariableStatsEstimate.builder() .setAverageRowSize(Math.max(left.getAverageRowSize(), right.getAverageRowSize())) .setNullsFraction(left.getNullsFraction() + right.getNullsFraction() - left.getNullsFraction() * right.getNullsFraction()) .setDistinctValuesCount(min(left.getDistinctValuesCount() * right.getDistinctValuesCount(), input.getOutputRowCount())); @@ -318,7 +318,7 @@ private double operate(OperatorType operator, double left, double right) } private class ExpressionStatsVisitor - extends AstVisitor + extends AstVisitor { private final PlanNodeStatsEstimate input; private final Session session; @@ -332,30 +332,31 @@ private class ExpressionStatsVisitor } @Override - protected SymbolStatsEstimate visitNode(Node node, Void context) + protected VariableStatsEstimate visitNode(Node node, Void context) { - return SymbolStatsEstimate.unknown(); + return VariableStatsEstimate.unknown(); } @Override - protected SymbolStatsEstimate visitSymbolReference(SymbolReference node, Void context) + protected VariableStatsEstimate visitSymbolReference(SymbolReference node, Void context) { - return input.getSymbolStatistics(Symbol.from(node)); + Symbol symbol = Symbol.from(node); + return input.getVariableStatistics(new VariableReferenceExpression(symbol.getName(), types.get(symbol))); } @Override - protected SymbolStatsEstimate visitNullLiteral(NullLiteral node, Void context) + protected VariableStatsEstimate visitNullLiteral(NullLiteral node, Void context) { return nullStatsEstimate(); } @Override - protected SymbolStatsEstimate visitLiteral(Literal node, Void context) + protected VariableStatsEstimate visitLiteral(Literal node, Void context) { Object value = evaluate(metadata, session.toConnectorSession(), node); Type type = ExpressionAnalyzer.createConstantAnalyzer(metadata, session, ImmutableList.of(), WarningCollector.NOOP).analyze(node, Scope.create()); OptionalDouble doubleValue = toStatsRepresentation(metadata, session, type, value); - SymbolStatsEstimate.Builder estimate = SymbolStatsEstimate.builder() + VariableStatsEstimate.Builder estimate = VariableStatsEstimate.builder() .setNullsFraction(0) .setDistinctValuesCount(1); @@ -367,7 +368,7 @@ protected SymbolStatsEstimate visitLiteral(Literal node, Void context) } @Override - protected SymbolStatsEstimate visitFunctionCall(FunctionCall node, Void context) + protected VariableStatsEstimate visitFunctionCall(FunctionCall node, Void context) { Map, Type> expressionTypes = getExpressionTypes(session, node, types); ExpressionInterpreter interpreter = ExpressionInterpreter.expressionOptimizer(node, metadata, session, expressionTypes); @@ -379,11 +380,11 @@ protected SymbolStatsEstimate visitFunctionCall(FunctionCall node, Void context) if (value instanceof Expression && !(value instanceof Literal)) { // value is not a constant - return SymbolStatsEstimate.unknown(); + return VariableStatsEstimate.unknown(); } // value is a constant - return SymbolStatsEstimate.builder() + return VariableStatsEstimate.builder() .setNullsFraction(0) .setDistinctValuesCount(1) .build(); @@ -405,9 +406,9 @@ private Map, Type> getExpressionTypes(Session session, Expre } @Override - protected SymbolStatsEstimate visitCast(Cast node, Void context) + protected VariableStatsEstimate visitCast(Cast node, Void context) { - SymbolStatsEstimate sourceStats = process(node.getExpression()); + VariableStatsEstimate sourceStats = process(node.getExpression()); TypeSignature targetType = TypeSignature.parseTypeSignature(node.getType()); // todo - make this general postprocessing rule. @@ -431,7 +432,7 @@ protected SymbolStatsEstimate visitCast(Cast node, Void context) } } - return SymbolStatsEstimate.builder() + return VariableStatsEstimate.builder() .setNullsFraction(sourceStats.getNullsFraction()) .setLowValue(lowValue) .setHighValue(highValue) @@ -440,14 +441,14 @@ protected SymbolStatsEstimate visitCast(Cast node, Void context) } @Override - protected SymbolStatsEstimate visitArithmeticUnary(ArithmeticUnaryExpression node, Void context) + protected VariableStatsEstimate visitArithmeticUnary(ArithmeticUnaryExpression node, Void context) { - SymbolStatsEstimate stats = process(node.getValue()); + VariableStatsEstimate stats = process(node.getValue()); switch (node.getSign()) { case PLUS: return stats; case MINUS: - return SymbolStatsEstimate.buildFrom(stats) + return VariableStatsEstimate.buildFrom(stats) .setLowValue(-stats.getHighValue()) .setHighValue(-stats.getLowValue()) .build(); @@ -457,13 +458,13 @@ protected SymbolStatsEstimate visitArithmeticUnary(ArithmeticUnaryExpression nod } @Override - protected SymbolStatsEstimate visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) + protected VariableStatsEstimate visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) { requireNonNull(node, "node is null"); - SymbolStatsEstimate left = process(node.getLeft()); - SymbolStatsEstimate right = process(node.getRight()); + VariableStatsEstimate left = process(node.getLeft()); + VariableStatsEstimate right = process(node.getRight()); - SymbolStatsEstimate.Builder result = SymbolStatsEstimate.builder() + VariableStatsEstimate.Builder result = VariableStatsEstimate.builder() .setAverageRowSize(Math.max(left.getAverageRowSize(), right.getAverageRowSize())) .setNullsFraction(left.getNullsFraction() + right.getNullsFraction() - left.getNullsFraction() * right.getNullsFraction()) .setDistinctValuesCount(min(left.getDistinctValuesCount() * right.getDistinctValuesCount(), input.getOutputRowCount())); @@ -529,12 +530,12 @@ private double operate(ArithmeticBinaryExpression.Operator operator, double left } @Override - protected SymbolStatsEstimate visitCoalesceExpression(CoalesceExpression node, Void context) + protected VariableStatsEstimate visitCoalesceExpression(CoalesceExpression node, Void context) { requireNonNull(node, "node is null"); - SymbolStatsEstimate result = null; + VariableStatsEstimate result = null; for (Expression operand : node.getOperands()) { - SymbolStatsEstimate operandEstimates = process(operand); + VariableStatsEstimate operandEstimates = process(operand); if (result != null) { result = estimateCoalesce(input, result, operandEstimates); } @@ -546,7 +547,7 @@ protected SymbolStatsEstimate visitCoalesceExpression(CoalesceExpression node, V } } - private static SymbolStatsEstimate estimateCoalesce(PlanNodeStatsEstimate input, SymbolStatsEstimate left, SymbolStatsEstimate right) + private static VariableStatsEstimate estimateCoalesce(PlanNodeStatsEstimate input, VariableStatsEstimate left, VariableStatsEstimate right) { // Question to reviewer: do you have a method to check if fraction is empty or saturated? if (left.getNullsFraction() == 0) { @@ -556,7 +557,7 @@ else if (left.getNullsFraction() == 1.0) { return right; } else { - return SymbolStatsEstimate.builder() + return VariableStatsEstimate.builder() .setLowValue(min(left.getLowValue(), right.getLowValue())) .setHighValue(max(left.getHighValue(), right.getHighValue())) .setDistinctValuesCount(left.getDistinctValuesCount() + @@ -568,9 +569,9 @@ else if (left.getNullsFraction() == 1.0) { } } - private static SymbolStatsEstimate nullStatsEstimate() + private static VariableStatsEstimate nullStatsEstimate() { - return SymbolStatsEstimate.builder() + return VariableStatsEstimate.builder() .setDistinctValuesCount(0) .setNullsFraction(1) .build(); diff --git a/presto-main/src/main/java/com/facebook/presto/cost/SemiJoinStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/SemiJoinStatsCalculator.java index 2d3ce5123e836..86d0a09bab270 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/SemiJoinStatsCalculator.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/SemiJoinStatsCalculator.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.cost; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import java.util.function.BiFunction; @@ -30,16 +30,16 @@ private SemiJoinStatsCalculator() {} // TODO implementation does not take into account overlapping of ranges for source and filtering source. // Basically it works as low and high values were the same for source and filteringSource and just looks at NDVs. - public static PlanNodeStatsEstimate computeSemiJoin(PlanNodeStatsEstimate sourceStats, PlanNodeStatsEstimate filteringSourceStats, Symbol sourceJoinSymbol, Symbol filteringSourceJoinSymbol) + public static PlanNodeStatsEstimate computeSemiJoin(PlanNodeStatsEstimate sourceStats, PlanNodeStatsEstimate filteringSourceStats, VariableReferenceExpression sourceJoinVariable, VariableReferenceExpression filteringSourceJoinVariable) { - return compute(sourceStats, filteringSourceStats, sourceJoinSymbol, filteringSourceJoinSymbol, + return compute(sourceStats, filteringSourceStats, sourceJoinVariable, filteringSourceJoinVariable, (sourceJoinSymbolStats, filteringSourceJoinSymbolStats) -> min(filteringSourceJoinSymbolStats.getDistinctValuesCount(), sourceJoinSymbolStats.getDistinctValuesCount())); } - public static PlanNodeStatsEstimate computeAntiJoin(PlanNodeStatsEstimate sourceStats, PlanNodeStatsEstimate filteringSourceStats, Symbol sourceJoinSymbol, Symbol filteringSourceJoinSymbol) + public static PlanNodeStatsEstimate computeAntiJoin(PlanNodeStatsEstimate sourceStats, PlanNodeStatsEstimate filteringSourceStats, VariableReferenceExpression sourceJoinVariable, VariableReferenceExpression filteringSourceJoinVariable) { - return compute(sourceStats, filteringSourceStats, sourceJoinSymbol, filteringSourceJoinSymbol, + return compute(sourceStats, filteringSourceStats, sourceJoinVariable, filteringSourceJoinVariable, (sourceJoinSymbolStats, filteringSourceJoinSymbolStats) -> max(sourceJoinSymbolStats.getDistinctValuesCount() * MIN_ANTI_JOIN_FILTER_COEFFICIENT, sourceJoinSymbolStats.getDistinctValuesCount() - filteringSourceJoinSymbolStats.getDistinctValuesCount())); @@ -48,15 +48,15 @@ public static PlanNodeStatsEstimate computeAntiJoin(PlanNodeStatsEstimate source private static PlanNodeStatsEstimate compute( PlanNodeStatsEstimate sourceStats, PlanNodeStatsEstimate filteringSourceStats, - Symbol sourceJoinSymbol, - Symbol filteringSourceJoinSymbol, - BiFunction retainedNdvProvider) + VariableReferenceExpression sourceJoinVariable, + VariableReferenceExpression filteringSourceJoinVariable, + BiFunction retainedNdvProvider) { - SymbolStatsEstimate sourceJoinSymbolStats = sourceStats.getSymbolStatistics(sourceJoinSymbol); - SymbolStatsEstimate filteringSourceJoinSymbolStats = filteringSourceStats.getSymbolStatistics(filteringSourceJoinSymbol); + VariableStatsEstimate sourceJoinSymbolStats = sourceStats.getVariableStatistics(sourceJoinVariable); + VariableStatsEstimate filteringSourceJoinSymbolStats = filteringSourceStats.getVariableStatistics(filteringSourceJoinVariable); double retainedNdv = retainedNdvProvider.apply(sourceJoinSymbolStats, filteringSourceJoinSymbolStats); - SymbolStatsEstimate newSourceJoinSymbolStats = SymbolStatsEstimate.buildFrom(sourceJoinSymbolStats) + VariableStatsEstimate newSourceJoinSymbolStats = VariableStatsEstimate.buildFrom(sourceJoinSymbolStats) .setNullsFraction(0) .setDistinctValuesCount(retainedNdv) .build(); @@ -64,7 +64,7 @@ private static PlanNodeStatsEstimate compute( double sourceDistinctValuesCount = sourceJoinSymbolStats.getDistinctValuesCount(); if (sourceDistinctValuesCount == 0) { return PlanNodeStatsEstimate.buildFrom(sourceStats) - .addSymbolStatistics(sourceJoinSymbol, newSourceJoinSymbolStats) + .addVariableStatistics(sourceJoinVariable, newSourceJoinSymbolStats) .setOutputRowCount(0) .build(); } @@ -72,7 +72,7 @@ private static PlanNodeStatsEstimate compute( double filterFactor = sourceJoinSymbolStats.getValuesFraction() * retainedNdv / sourceDistinctValuesCount; double outputRowCount = sourceStats.getOutputRowCount() * filterFactor; return PlanNodeStatsEstimate.buildFrom(sourceStats) - .addSymbolStatistics(sourceJoinSymbol, newSourceJoinSymbolStats) + .addVariableStatistics(sourceJoinVariable, newSourceJoinSymbolStats) .setOutputRowCount(outputRowCount) .build(); } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/SimpleFilterProjectSemiJoinStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/SimpleFilterProjectSemiJoinStatsRule.java index dbdd50bd56cf2..7f2dff68a8adc 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/SimpleFilterProjectSemiJoinStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/SimpleFilterProjectSemiJoinStatsRule.java @@ -20,7 +20,6 @@ import com.facebook.presto.spi.relation.LogicalRowExpressions; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.FilterNode; @@ -109,16 +108,16 @@ private Optional calculate(FilterNode filterNode, SemiJoi { PlanNodeStatsEstimate sourceStats = statsProvider.getStats(semiJoinNode.getSource()); PlanNodeStatsEstimate filteringSourceStats = statsProvider.getStats(semiJoinNode.getFilteringSource()); - Symbol filteringSourceJoinSymbol = semiJoinNode.getFilteringSourceJoinSymbol(); - Symbol sourceJoinSymbol = semiJoinNode.getSourceJoinSymbol(); + VariableReferenceExpression filteringSourceJoinVariable = semiJoinNode.getFilteringSourceJoinVariable(); + VariableReferenceExpression sourceJoinVariable = semiJoinNode.getSourceJoinVariable(); Optional semiJoinOutputFilter; - Symbol semiJoinOutput = semiJoinNode.getSemiJoinOutput(); + VariableReferenceExpression semiJoinOutput = semiJoinNode.getSemiJoinOutput(); if (isExpression(filterNode.getPredicate())) { semiJoinOutputFilter = extractSemiJoinOutputFilter(castToExpression(filterNode.getPredicate()), semiJoinOutput); } else { - semiJoinOutputFilter = extractSemiJoinOutputFilter(filterNode.getPredicate(), new VariableReferenceExpression(semiJoinOutput.getName(), types.get(semiJoinOutput))); + semiJoinOutputFilter = extractSemiJoinOutputFilter(filterNode.getPredicate(), semiJoinOutput); } if (!semiJoinOutputFilter.isPresent()) { @@ -127,10 +126,10 @@ private Optional calculate(FilterNode filterNode, SemiJoi PlanNodeStatsEstimate semiJoinStats; if (semiJoinOutputFilter.get().isNegated()) { - semiJoinStats = computeAntiJoin(sourceStats, filteringSourceStats, sourceJoinSymbol, filteringSourceJoinSymbol); + semiJoinStats = computeAntiJoin(sourceStats, filteringSourceStats, sourceJoinVariable, filteringSourceJoinVariable); } else { - semiJoinStats = computeSemiJoin(sourceStats, filteringSourceStats, sourceJoinSymbol, filteringSourceJoinSymbol); + semiJoinStats = computeSemiJoin(sourceStats, filteringSourceStats, sourceJoinVariable, filteringSourceJoinVariable); } if (semiJoinStats.isOutputRowCountUnknown()) { @@ -152,7 +151,7 @@ private Optional calculate(FilterNode filterNode, SemiJoi return Optional.of(filteredStats); } - private Optional extractSemiJoinOutputFilter(Expression predicate, Symbol semiJoinOutput) + private Optional extractSemiJoinOutputFilter(Expression predicate, VariableReferenceExpression semiJoinOutput) { List conjuncts = extractConjuncts(predicate); List semiJoinOutputReferences = conjuncts.stream() @@ -196,9 +195,9 @@ private boolean isSemiJoinOutputReference(RowExpression conjunct, RowExpression return conjunct.equals(input) || (isNotFunction(conjunct) && ((CallExpression) conjunct).getArguments().get(0).equals(input)); } - private static boolean isSemiJoinOutputReference(Expression conjunct, Symbol semiJoinOutput) + private static boolean isSemiJoinOutputReference(Expression conjunct, VariableReferenceExpression semiJoinOutput) { - SymbolReference semiJoinOuputSymbolReference = semiJoinOutput.toSymbolReference(); + SymbolReference semiJoinOuputSymbolReference = new SymbolReference(semiJoinOutput.getName()); return conjunct.equals(semiJoinOuputSymbolReference) || (conjunct instanceof NotExpression && ((NotExpression) conjunct).getValue().equals(semiJoinOuputSymbolReference)); } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/SimpleStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/SimpleStatsRule.java index 8cf61e4882dbb..9af1613273cc5 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/SimpleStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/SimpleStatsRule.java @@ -37,7 +37,7 @@ protected SimpleStatsRule(StatsNormalizer normalizer) public final Optional calculate(T node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) { return doCalculate(node, sourceStats, lookup, session, types) - .map(estimate -> normalizer.normalize(estimate, node.getOutputSymbols(), types)); + .map(estimate -> normalizer.normalize(estimate, node.getOutputVariables())); } protected abstract Optional doCalculate(T node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types); diff --git a/presto-main/src/main/java/com/facebook/presto/cost/SpatialJoinStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/SpatialJoinStatsRule.java index 8a2482f1da3c9..fb73ab05463ad 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/SpatialJoinStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/SpatialJoinStatsRule.java @@ -66,8 +66,8 @@ private PlanNodeStatsEstimate crossJoinStats(SpatialJoinNode node, PlanNodeStats PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder() .setOutputRowCount(leftStats.getOutputRowCount() * rightStats.getOutputRowCount()); - node.getLeft().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, leftStats.getSymbolStatistics(symbol))); - node.getRight().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, rightStats.getSymbolStatistics(symbol))); + node.getLeft().getOutputVariables().forEach(variable -> builder.addVariableStatistics(variable, leftStats.getVariableStatistics(variable))); + node.getRight().getOutputVariables().forEach(variable -> builder.addVariableStatistics(variable, rightStats.getVariableStatistics(variable))); return builder.build(); } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/StatisticRange.java b/presto-main/src/main/java/com/facebook/presto/cost/StatisticRange.java index fd12f36df79ad..b613f655dc01d 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/StatisticRange.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/StatisticRange.java @@ -62,7 +62,7 @@ public static StatisticRange empty() return new StatisticRange(NaN, NaN, 0); } - public static StatisticRange from(SymbolStatsEstimate estimate) + public static StatisticRange from(VariableStatsEstimate estimate) { return new StatisticRange(estimate.getLowValue(), estimate.getHighValue(), estimate.getDistinctValuesCount()); } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/StatsNormalizer.java b/presto-main/src/main/java/com/facebook/presto/cost/StatsNormalizer.java index 42d57d6f56caf..9c40bbe2c12f4 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/StatsNormalizer.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/StatsNormalizer.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.cost; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.DateType; @@ -21,8 +22,6 @@ import com.facebook.presto.spi.type.SmallintType; import com.facebook.presto.spi.type.TinyintType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.TypeProvider; import com.google.common.collect.ImmutableSet; import java.util.Collection; @@ -35,24 +34,23 @@ import static java.lang.Double.isNaN; import static java.lang.Math.floor; import static java.lang.Math.pow; -import static java.util.Objects.requireNonNull; /** * Makes stats consistent */ public class StatsNormalizer { - public PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate stats, TypeProvider types) + public PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate stats) { - return normalize(stats, Optional.empty(), types); + return normalize(stats, Optional.empty()); } - public PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate stats, Collection outputSymbols, TypeProvider types) + public PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate stats, Collection outputVariables) { - return normalize(stats, Optional.of(outputSymbols), types); + return normalize(stats, Optional.of(outputVariables)); } - private PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate stats, Optional> outputSymbols, TypeProvider types) + private PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate stats, Optional> outputVariables) { if (stats.isOutputRowCountUnknown()) { return PlanNodeStatsEstimate.unknown(); @@ -60,25 +58,25 @@ private PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate stats, Optional symbolFilter = outputSymbols + Predicate variableFilter = outputVariables .map(ImmutableSet::copyOf) - .map(set -> (Predicate) set::contains) - .orElse(symbol -> true); + .map(set -> (Predicate) set::contains) + .orElse(variable -> true); - for (Symbol symbol : stats.getSymbolsWithKnownStatistics()) { - if (!symbolFilter.test(symbol)) { - normalized.removeSymbolStatistics(symbol); + for (VariableReferenceExpression variable : stats.getVariablesWithKnownStatistics()) { + if (!variableFilter.test(variable)) { + normalized.removeVariableStatistics(variable); continue; } - SymbolStatsEstimate symbolStats = stats.getSymbolStatistics(symbol); - SymbolStatsEstimate normalizedSymbolStats = stats.getOutputRowCount() == 0 ? SymbolStatsEstimate.zero() : normalizeSymbolStats(symbol, symbolStats, stats, types); + VariableStatsEstimate variableStats = stats.getVariableStatistics(variable); + VariableStatsEstimate normalizedSymbolStats = stats.getOutputRowCount() == 0 ? VariableStatsEstimate.zero() : normalizeVariableStats(variable, variableStats, stats); if (normalizedSymbolStats.isUnknown()) { - normalized.removeSymbolStatistics(symbol); + normalized.removeVariableStatistics(variable); continue; } - if (!Objects.equals(normalizedSymbolStats, symbolStats)) { - normalized.addSymbolStatistics(symbol, normalizedSymbolStats); + if (!Objects.equals(normalizedSymbolStats, variableStats)) { + normalized.addVariableStatistics(variable, normalizedSymbolStats); } } @@ -88,20 +86,20 @@ private PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate stats, Optional 0, "outputRowCount must be greater than zero: %s", outputRowCount); - double distinctValuesCount = symbolStats.getDistinctValuesCount(); - double nullsFraction = symbolStats.getNullsFraction(); + double distinctValuesCount = variableStats.getDistinctValuesCount(); + double nullsFraction = variableStats.getNullsFraction(); if (!isNaN(distinctValuesCount)) { - Type type = requireNonNull(types.get(symbol), () -> "type is missing for symbol " + symbol); - double maxDistinctValuesByLowHigh = maxDistinctValuesByLowHigh(symbolStats, type); + Type type = variable.getType(); + double maxDistinctValuesByLowHigh = maxDistinctValuesByLowHigh(variableStats, type); if (distinctValuesCount > maxDistinctValuesByLowHigh) { distinctValuesCount = maxDistinctValuesByLowHigh; } @@ -120,18 +118,18 @@ private SymbolStatsEstimate normalizeSymbolStats(Symbol symbol, SymbolStatsEstim } if (distinctValuesCount == 0.0) { - return SymbolStatsEstimate.zero(); + return VariableStatsEstimate.zero(); } - return SymbolStatsEstimate.buildFrom(symbolStats) + return VariableStatsEstimate.buildFrom(variableStats) .setDistinctValuesCount(distinctValuesCount) .setNullsFraction(nullsFraction) .build(); } - private double maxDistinctValuesByLowHigh(SymbolStatsEstimate symbolStats, Type type) + private double maxDistinctValuesByLowHigh(VariableStatsEstimate variableStats, Type type) { - if (symbolStats.statisticRange().length() == 0.0) { + if (variableStats.statisticRange().length() == 0.0) { return 1; } @@ -139,7 +137,7 @@ private double maxDistinctValuesByLowHigh(SymbolStatsEstimate symbolStats, Type return NaN; } - double length = symbolStats.getHighValue() - symbolStats.getLowValue(); + double length = variableStats.getHighValue() - variableStats.getLowValue(); if (isNaN(length)) { return NaN; } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java index 3913cecf3dca5..305e076aa0c26 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java @@ -18,9 +18,9 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.statistics.ColumnStatistics; import com.facebook.presto.spi.statistics.TableStatistics; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.TableScanNode; @@ -58,26 +58,25 @@ protected Optional doCalculate(TableScanNode node, StatsP Constraint constraint = new Constraint<>(node.getCurrentConstraint()); TableStatistics tableStatistics = metadata.getTableStatistics(session, node.getTable(), constraint); - Map outputSymbolStats = new HashMap<>(); + Map outputVariableStats = new HashMap<>(); - for (Map.Entry entry : node.getAssignments().entrySet()) { - Symbol symbol = entry.getKey(); + for (Map.Entry entry : node.getAssignments().entrySet()) { Optional columnStatistics = Optional.ofNullable(tableStatistics.getColumnStatistics().get(entry.getValue())); - outputSymbolStats.put(symbol, columnStatistics.map(statistics -> toSymbolStatistics(tableStatistics, statistics)).orElse(SymbolStatsEstimate.unknown())); + outputVariableStats.put(entry.getKey(), columnStatistics.map(statistics -> toSymbolStatistics(tableStatistics, statistics)).orElse(VariableStatsEstimate.unknown())); } return Optional.of(PlanNodeStatsEstimate.builder() .setOutputRowCount(tableStatistics.getRowCount().getValue()) - .addSymbolStatistics(outputSymbolStats) + .addVariableStatistics(outputVariableStats) .build()); } - private SymbolStatsEstimate toSymbolStatistics(TableStatistics tableStatistics, ColumnStatistics columnStatistics) + private VariableStatsEstimate toSymbolStatistics(TableStatistics tableStatistics, ColumnStatistics columnStatistics) { double nullsFraction = columnStatistics.getNullsFraction().getValue(); double nonNullRowsCount = tableStatistics.getRowCount().getValue() * (1.0 - nullsFraction); double averageRowSize = nonNullRowsCount == 0 ? 0 : columnStatistics.getDataSize().getValue() / nonNullRowsCount; - SymbolStatsEstimate.Builder result = SymbolStatsEstimate.builder(); + VariableStatsEstimate.Builder result = VariableStatsEstimate.builder(); result.setNullsFraction(nullsFraction); result.setDistinctValuesCount(columnStatistics.getDistinctValuesCount().getValue()); result.setAverageRowSize(averageRowSize); diff --git a/presto-main/src/main/java/com/facebook/presto/cost/UnionStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/UnionStatsRule.java index d3405c3ca32e1..239677ffae390 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/UnionStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/UnionStatsRule.java @@ -16,7 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -55,7 +55,7 @@ protected final Optional doCalculate(UnionNode node, Stat PlanNode source = node.getSources().get(i); PlanNodeStatsEstimate sourceStats = statsProvider.getStats(source); - PlanNodeStatsEstimate sourceStatsWithMappedSymbols = mapToOutputSymbols(sourceStats, node.getSymbolMapping(), i); + PlanNodeStatsEstimate sourceStatsWithMappedSymbols = mapToOutputSymbols(sourceStats, node.getVariableMapping(), i); if (estimate.isPresent()) { estimate = Optional.of(addStatsAndCollapseDistinctValues(estimate.get(), sourceStatsWithMappedSymbols)); @@ -68,13 +68,13 @@ protected final Optional doCalculate(UnionNode node, Stat return estimate; } - private PlanNodeStatsEstimate mapToOutputSymbols(PlanNodeStatsEstimate estimate, ListMultimap mapping, int index) + private PlanNodeStatsEstimate mapToOutputSymbols(PlanNodeStatsEstimate estimate, ListMultimap mapping, int index) { PlanNodeStatsEstimate.Builder mapped = PlanNodeStatsEstimate.builder() .setOutputRowCount(estimate.getOutputRowCount()); mapping.keySet().stream() - .forEach(symbol -> mapped.addSymbolStatistics(symbol, estimate.getSymbolStatistics(mapping.get(symbol).get(index)))); + .forEach(variable -> mapped.addVariableStatistics(variable, estimate.getVariableStatistics(mapping.get(variable).get(index)))); return mapped.build(); } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/UnnestStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/UnnestStatsRule.java index 340d2e291bbcd..3ee92f0046d6c 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/UnnestStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/UnnestStatsRule.java @@ -16,7 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.ComposableStatsCalculator.Rule; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.UnnestNode; @@ -52,28 +52,28 @@ public Optional calculate(UnnestNode node, StatsProvider // Thus we'd still populate the inaccurate numbers just so stats are populated to enable optimization // potential. calculatedStats.setOutputRowCount(sourceStats.getOutputRowCount()); - for (Symbol symbol : node.getReplicateSymbols()) { - calculatedStats.addSymbolStatistics(symbol, sourceStats.getSymbolStatistics(symbol)); + for (VariableReferenceExpression variable : node.getReplicateVariables()) { + calculatedStats.addVariableStatistics(variable, sourceStats.getVariableStatistics(variable)); } - for (Map.Entry> entry : node.getUnnestSymbols().entrySet()) { - List unnestToSymbols = entry.getValue(); - SymbolStatsEstimate stats = sourceStats.getSymbolStatistics(entry.getKey()); - for (Symbol symbol : unnestToSymbols) { + for (Map.Entry> entry : node.getUnnestVariables().entrySet()) { + List unnestToVariables = entry.getValue(); + VariableStatsEstimate stats = sourceStats.getVariableStatistics(entry.getKey()); + for (VariableReferenceExpression variable : unnestToVariables) { // This is a very conservative way on estimating stats after unnest. We assume each symbol // after unnest would have as much data as the symbol before unnest. This would over // estimate, which are more likely to mean we'd loose an optimization opportunity, but at // least it won't cause false optimizations. - calculatedStats.addSymbolStatistics( - symbol, - SymbolStatsEstimate.builder() + calculatedStats.addVariableStatistics( + variable, + VariableStatsEstimate.builder() .setAverageRowSize(stats.getAverageRowSize()) .build()); } } - if (node.getOrdinalitySymbol().isPresent()) { - calculatedStats.addSymbolStatistics( - node.getOrdinalitySymbol().get(), - SymbolStatsEstimate.builder() + if (node.getOrdinalityVariable().isPresent()) { + calculatedStats.addVariableStatistics( + node.getOrdinalityVariable().get(), + VariableStatsEstimate.builder() .setLowValue(0) .setNullsFraction(0) .build()); diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ValuesStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/ValuesStatsRule.java index 7da57cc4438aa..e271631b5ae40 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ValuesStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/ValuesStatsRule.java @@ -17,8 +17,8 @@ import com.facebook.presto.cost.ComposableStatsCalculator.Rule; import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.ValuesNode; @@ -65,18 +65,18 @@ public Optional calculate(ValuesNode node, StatsProvider PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder(); statsBuilder.setOutputRowCount(node.getRows().size()); - for (int symbolId = 0; symbolId < node.getOutputSymbols().size(); ++symbolId) { - Symbol symbol = node.getOutputSymbols().get(symbolId); - List symbolValues = getSymbolValues(node, symbolId, session, types.get(symbol)); - statsBuilder.addSymbolStatistics(symbol, buildSymbolStatistics(symbolValues, session, types.get(symbol))); + for (int variableId = 0; variableId < node.getOutputVariables().size(); ++variableId) { + VariableReferenceExpression variable = node.getOutputVariables().get(variableId); + List symbolValues = getVariableValues(node, variableId, session, variable.getType()); + statsBuilder.addVariableStatistics(variable, buildVariableStatistics(symbolValues, session, variable.getType())); } return Optional.of(statsBuilder.build()); } - private List getSymbolValues(ValuesNode valuesNode, int symbolId, Session session, Type symbolType) + private List getVariableValues(ValuesNode valuesNode, int symbolId, Session session, Type type) { - if (UNKNOWN.equals(symbolType)) { + if (UNKNOWN.equals(type)) { // special casing for UNKNOWN as evaluateConstantExpression does not handle that return IntStream.range(0, valuesNode.getRows().size()) .mapToObj(rowId -> null) @@ -86,21 +86,21 @@ private List getSymbolValues(ValuesNode valuesNode, int symbolId, Sessio .map(row -> row.get(symbolId)) .map(rowExpression -> { if (isExpression(rowExpression)) { - return evaluateConstantExpression(castToExpression(rowExpression), symbolType, metadata, session, ImmutableList.of()); + return evaluateConstantExpression(castToExpression(rowExpression), type, metadata, session, ImmutableList.of()); } return evaluateConstantRowExpression(rowExpression, metadata, session.toConnectorSession()); }) .collect(toList()); } - private SymbolStatsEstimate buildSymbolStatistics(List values, Session session, Type type) + private VariableStatsEstimate buildVariableStatistics(List values, Session session, Type type) { List nonNullValues = values.stream() .filter(Objects::nonNull) .collect(toImmutableList()); if (nonNullValues.isEmpty()) { - return SymbolStatsEstimate.zero(); + return VariableStatsEstimate.zero(); } double[] valuesAsDoubles = nonNullValues.stream() @@ -115,7 +115,7 @@ private SymbolStatsEstimate buildSymbolStatistics(List values, Session s double nonNullValuesCount = nonNullValues.size(); long distinctValuesCount = nonNullValues.stream().distinct().count(); - return SymbolStatsEstimate.builder() + return VariableStatsEstimate.builder() .setNullsFraction((valuesCount - nonNullValuesCount) / valuesCount) .setLowValue(lowValue) .setHighValue(highValue) diff --git a/presto-main/src/main/java/com/facebook/presto/cost/SymbolStatsEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/VariableStatsEstimate.java similarity index 88% rename from presto-main/src/main/java/com/facebook/presto/cost/SymbolStatsEstimate.java rename to presto-main/src/main/java/com/facebook/presto/cost/VariableStatsEstimate.java index 900dbdf11ef52..2e13cf5a0fde7 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/SymbolStatsEstimate.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/VariableStatsEstimate.java @@ -28,10 +28,10 @@ import static java.lang.Double.isNaN; import static java.lang.String.format; -public class SymbolStatsEstimate +public class VariableStatsEstimate { - private static final SymbolStatsEstimate UNKNOWN = new SymbolStatsEstimate(NEGATIVE_INFINITY, POSITIVE_INFINITY, NaN, NaN, NaN); - private static final SymbolStatsEstimate ZERO = new SymbolStatsEstimate(NaN, NaN, 1.0, 0.0, 0.0); + private static final VariableStatsEstimate UNKNOWN = new VariableStatsEstimate(NEGATIVE_INFINITY, POSITIVE_INFINITY, NaN, NaN, NaN); + private static final VariableStatsEstimate ZERO = new VariableStatsEstimate(NaN, NaN, 1.0, 0.0, 0.0); // for now we support only types which map to real domain naturally and keep low/high value as double in stats. private final double lowValue; @@ -40,18 +40,18 @@ public class SymbolStatsEstimate private final double averageRowSize; private final double distinctValuesCount; - public static SymbolStatsEstimate unknown() + public static VariableStatsEstimate unknown() { return UNKNOWN; } - public static SymbolStatsEstimate zero() + public static VariableStatsEstimate zero() { return ZERO; } @JsonCreator - public SymbolStatsEstimate( + public VariableStatsEstimate( @JsonProperty("lowValue") double lowValue, @JsonProperty("highValue") double highValue, @JsonProperty("nullsFraction") double nullsFraction, @@ -121,12 +121,12 @@ public double getDistinctValuesCount() return distinctValuesCount; } - public SymbolStatsEstimate mapNullsFraction(Function mappingFunction) + public VariableStatsEstimate mapNullsFraction(Function mappingFunction) { return buildFrom(this).setNullsFraction(mappingFunction.apply(nullsFraction)).build(); } - public SymbolStatsEstimate mapDistinctValuesCount(Function mappingFunction) + public VariableStatsEstimate mapDistinctValuesCount(Function mappingFunction) { return buildFrom(this).setDistinctValuesCount(mappingFunction.apply(distinctValuesCount)).build(); } @@ -152,7 +152,7 @@ public boolean equals(Object o) if (o == null || getClass() != o.getClass()) { return false; } - SymbolStatsEstimate that = (SymbolStatsEstimate) o; + VariableStatsEstimate that = (VariableStatsEstimate) o; return Double.compare(nullsFraction, that.nullsFraction) == 0 && Double.compare(averageRowSize, that.averageRowSize) == 0 && Double.compare(distinctValuesCount, that.distinctValuesCount) == 0 && @@ -182,7 +182,7 @@ public static Builder builder() return new Builder(); } - public static Builder buildFrom(SymbolStatsEstimate other) + public static Builder buildFrom(VariableStatsEstimate other) { return builder() .setLowValue(other.getLowValue()) @@ -237,9 +237,9 @@ public Builder setDistinctValuesCount(double distinctValuesCount) return this; } - public SymbolStatsEstimate build() + public VariableStatsEstimate build() { - return new SymbolStatsEstimate(lowValue, highValue, nullsFraction, averageRowSize, distinctValuesCount); + return new VariableStatsEstimate(lowValue, highValue, nullsFraction, averageRowSize, distinctValuesCount); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/execution/SqlTaskExecutionFactory.java b/presto-main/src/main/java/com/facebook/presto/execution/SqlTaskExecutionFactory.java index d1f92222bd69b..bd22e001f64df 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/SqlTaskExecutionFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/SqlTaskExecutionFactory.java @@ -87,7 +87,7 @@ public SqlTaskExecution create( localExecutionPlan = planner.plan( taskContext, fragment.getRoot(), - TypeProvider.copyOf(fragment.getSymbols()), + TypeProvider.fromVariables(fragment.getVariables()), fragment.getPartitioningScheme(), fragment.getStageExecutionDescriptor(), fragment.getTableScanSchedulingOrder(), diff --git a/presto-main/src/main/java/com/facebook/presto/operator/InterpretedHashGenerator.java b/presto-main/src/main/java/com/facebook/presto/operator/InterpretedHashGenerator.java index a306e6cf0cd4a..2d75575b2240d 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/InterpretedHashGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/InterpretedHashGenerator.java @@ -17,13 +17,13 @@ import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.optimizations.HashGenerationOptimizer; import com.facebook.presto.type.TypeUtils; import com.google.common.collect.ImmutableList; import java.util.List; import java.util.function.IntFunction; +import static com.facebook.presto.sql.planner.optimizations.HashGenerationOptimizer.INITIAL_HASH_VALUE; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; @@ -54,7 +54,7 @@ public long hashPosition(int position, Page page) public long hashPosition(int position, IntFunction blockProvider) { - long result = HashGenerationOptimizer.INITIAL_HASH_VALUE; + long result = INITIAL_HASH_VALUE; for (int i = 0; i < hashChannels.length; i++) { Type type = hashChannelTypes.get(i); result = CombineHashFunction.getHash(result, TypeUtils.hashPosition(type, blockProvider.apply(hashChannels[i]), position)); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/LookupSourceFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/LookupSourceFactory.java index 9bfa551300f06..9c328fb6bcd73 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/LookupSourceFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/LookupSourceFactory.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.operator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.Symbol; import com.google.common.util.concurrent.ListenableFuture; import java.util.List; @@ -53,7 +53,7 @@ default ListenableFuture>> finishP @Override OuterPositionIterator getOuterPositionIterator(); - Map getLayout(); + Map getLayout(); // this is only here for the index lookup source default void setTaskContext(TaskContext taskContext) {} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PartitionedLookupSourceFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/PartitionedLookupSourceFactory.java index 7a4073d17b8c0..30bd096edf4be 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PartitionedLookupSourceFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PartitionedLookupSourceFactory.java @@ -16,8 +16,8 @@ import com.facebook.presto.operator.LookupSourceProvider.LookupSourceLease; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PageBuilder; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.Symbol; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -58,7 +58,7 @@ public final class PartitionedLookupSourceFactory { private final List types; private final List outputTypes; - private final Map layout; + private final Map layout; private final List hashChannelTypes; private final boolean outer; private final SpilledLookupSource spilledLookupSource; @@ -106,7 +106,7 @@ public final class PartitionedLookupSourceFactory */ private final ConcurrentHashMap suppliedLookupSources = new ConcurrentHashMap<>(); - public PartitionedLookupSourceFactory(List types, List outputTypes, List hashChannelTypes, int partitionCount, Map layout, boolean outer) + public PartitionedLookupSourceFactory(List types, List outputTypes, List hashChannelTypes, int partitionCount, Map layout, boolean outer) { checkArgument(Integer.bitCount(partitionCount) == 1, "partitionCount must be a power of 2"); @@ -133,7 +133,7 @@ public List getOutputTypes() } @Override - public Map getLayout() + public Map getLayout() { return layout; } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PartitionedOutputOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/PartitionedOutputOperator.java index fa951efab812f..80c6d42a99dd3 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PartitionedOutputOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PartitionedOutputOperator.java @@ -24,7 +24,7 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.RunLengthEncodedBlock; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.spi.predicate.NullableValue; +import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.util.Mergeable; import com.fasterxml.jackson.annotation.JsonCreator; @@ -54,7 +54,7 @@ public static class PartitionedOutputFactory { private final PartitionFunction partitionFunction; private final List partitionChannels; - private final List> partitionConstants; + private final List> partitionConstants; private final OutputBuffer outputBuffer; private final boolean replicatesAnyRow; private final OptionalInt nullChannel; @@ -63,7 +63,7 @@ public static class PartitionedOutputFactory public PartitionedOutputFactory( PartitionFunction partitionFunction, List partitionChannels, - List> partitionConstants, + List> partitionConstants, boolean replicatesAnyRow, OptionalInt nullChannel, OutputBuffer outputBuffer, @@ -111,7 +111,7 @@ public static class PartitionedOutputOperatorFactory private final Function pagePreprocessor; private final PartitionFunction partitionFunction; private final List partitionChannels; - private final List> partitionConstants; + private final List> partitionConstants; private final boolean replicatesAnyRow; private final OptionalInt nullChannel; private final OutputBuffer outputBuffer; @@ -125,7 +125,7 @@ public PartitionedOutputOperatorFactory( Function pagePreprocessor, PartitionFunction partitionFunction, List partitionChannels, - List> partitionConstants, + List> partitionConstants, boolean replicatesAnyRow, OptionalInt nullChannel, OutputBuffer outputBuffer, @@ -201,7 +201,7 @@ public PartitionedOutputOperator( Function pagePreprocessor, PartitionFunction partitionFunction, List partitionChannels, - List> partitionConstants, + List> partitionConstants, boolean replicatesAnyRow, OptionalInt nullChannel, OutputBuffer outputBuffer, @@ -313,7 +313,7 @@ private static class PagePartitioner public PagePartitioner( PartitionFunction partitionFunction, List partitionChannels, - List> partitionConstants, + List> partitionConstants, boolean replicatesAnyRow, OptionalInt nullChannel, OutputBuffer outputBuffer, @@ -325,7 +325,7 @@ public PagePartitioner( this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null"); this.partitionChannels = requireNonNull(partitionChannels, "partitionChannels is null"); this.partitionConstants = requireNonNull(partitionConstants, "partitionConstants is null").stream() - .map(constant -> constant.map(NullableValue::asBlock)) + .map(constant -> constant.map(ConstantExpression::getValueBlock)) .collect(toImmutableList()); this.replicatesAnyRow = replicatesAnyRow; this.nullChannel = requireNonNull(nullChannel, "nullChannel is null"); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/index/IndexLookupSourceFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/index/IndexLookupSourceFactory.java index 8cd0510859c7a..12ffdc49cc2f0 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/index/IndexLookupSourceFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/index/IndexLookupSourceFactory.java @@ -19,9 +19,9 @@ import com.facebook.presto.operator.PagesIndex; import com.facebook.presto.operator.StaticLookupSourceProvider; import com.facebook.presto.operator.TaskContext; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.JoinCompiler; -import com.facebook.presto.sql.planner.Symbol; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.Futures; @@ -44,7 +44,7 @@ public class IndexLookupSourceFactory implements LookupSourceFactory { private final List outputTypes; - private final Map layout; + private final Map layout; private final Supplier indexLoaderSupplier; private TaskContext taskContext; private final SettableFuture whenTaskContextSet = SettableFuture.create(); @@ -54,7 +54,7 @@ public IndexLookupSourceFactory( List keyOutputChannels, OptionalInt keyOutputHashChannel, List outputTypes, - Map layout, + Map layout, IndexBuildDriverFactoryProvider indexBuildDriverFactoryProvider, DataSize maxIndexMemorySize, IndexJoinLookupStats stats, @@ -87,7 +87,7 @@ public List getOutputTypes() } @Override - public Map getLayout() + public Map getLayout() { return layout; } diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java index b32003fc72770..2fd6b34dc9189 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java @@ -87,6 +87,7 @@ import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.DomainTranslator; import com.facebook.presto.spi.relation.PredicateCompiler; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spiller.FileSingleStreamSpillerFactory; @@ -106,6 +107,8 @@ import com.facebook.presto.sql.Serialization.ExpressionDeserializer; import com.facebook.presto.sql.Serialization.ExpressionSerializer; import com.facebook.presto.sql.Serialization.FunctionCallDeserializer; +import com.facebook.presto.sql.Serialization.VariableReferenceExpressionDeserializer; +import com.facebook.presto.sql.Serialization.VariableReferenceExpressionSerializer; import com.facebook.presto.sql.SqlEnvironmentConfig; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.gen.ExpressionCompiler; @@ -377,6 +380,10 @@ protected void setup(Binder binder) jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); newSetBinder(binder, Type.class); + // plan + jsonBinder(binder).addKeySerializerBinding(VariableReferenceExpression.class).to(VariableReferenceExpressionSerializer.class); + jsonBinder(binder).addKeyDeserializerBinding(VariableReferenceExpression.class).to(VariableReferenceExpressionDeserializer.class); + // split manager binder.bind(SplitManager.class).in(Scopes.SINGLETON); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java index 72d0b829148c0..16564a94c398a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java @@ -13,9 +13,11 @@ */ package com.facebook.presto.sql; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.ExpressionDeterminismEvaluator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; @@ -258,24 +260,24 @@ public static boolean referencesAny(Expression expression, Collection va return variables.stream().anyMatch(references::contains); } - public static Function expressionOrNullSymbols(final Predicate... nullSymbolScopes) + public static Function expressionOrNullVariables(TypeProvider types, final Predicate... nullVariableScopes) { return expression -> { ImmutableList.Builder resultDisjunct = ImmutableList.builder(); resultDisjunct.add(expression); - for (Predicate nullSymbolScope : nullSymbolScopes) { - List symbols = SymbolsExtractor.extractUnique(expression).stream() - .filter(nullSymbolScope) + for (Predicate nullVariableScope : nullVariableScopes) { + List variables = SymbolsExtractor.extractUniqueVariable(expression, types).stream() + .filter(nullVariableScope) .collect(toImmutableList()); - if (Iterables.isEmpty(symbols)) { + if (Iterables.isEmpty(variables)) { continue; } ImmutableList.Builder nullConjuncts = ImmutableList.builder(); - for (Symbol symbol : symbols) { - nullConjuncts.add(new IsNullPredicate(symbol.toSymbolReference())); + for (VariableReferenceExpression variable : variables) { + nullConjuncts.add(new IsNullPredicate(new SymbolReference(variable.getName()))); } resultDisjunct.add(and(nullConjuncts.build())); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/Serialization.java b/presto-main/src/main/java/com/facebook/presto/sql/Serialization.java index dee60036f4827..9289aa4c3a908 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/Serialization.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/Serialization.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.sql; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; @@ -21,6 +23,7 @@ import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonDeserializer; import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.KeyDeserializer; import com.fasterxml.jackson.databind.SerializerProvider; import javax.inject.Inject; @@ -28,7 +31,9 @@ import java.io.IOException; import java.util.Optional; +import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; +import static java.lang.String.format; public final class Serialization { @@ -82,4 +87,38 @@ public FunctionCall deserialize(JsonParser jsonParser, DeserializationContext de return (FunctionCall) rewriteIdentifiersToSymbolReferences(sqlParser.createExpression(jsonParser.getText())); } } + + public static class VariableReferenceExpressionSerializer + extends JsonSerializer + { + @Override + public void serialize(VariableReferenceExpression value, JsonGenerator jsonGenerator, SerializerProvider serializers) + throws IOException + { + jsonGenerator.writeFieldName(format("%s(%s)", value.getName(), value.getType())); + } + } + + public static class VariableReferenceExpressionDeserializer + extends KeyDeserializer + { + private final TypeManager typeManager; + + @Inject + public VariableReferenceExpressionDeserializer(TypeManager typeManager) + { + this.typeManager = typeManager; + } + + @Override + public Object deserializeKey(String key, DeserializationContext ctxt) + throws IOException + { + int p = key.indexOf("("); + if (p <= 0 || key.charAt(key.length() - 1) != ')') { + throw new IllegalArgumentException(format("Expect key to be of format 'name(type)', found %s", key)); + } + return new VariableReferenceExpression(key.substring(0, p), typeManager.getType(parseTypeSignature(key.substring(p + 1, key.length() - 1)))); + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java index fb17ccb9defc0..4d3d57afd0ac5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner; import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; @@ -53,7 +54,7 @@ import java.util.function.Predicate; import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts; -import static com.facebook.presto.sql.ExpressionUtils.expressionOrNullSymbols; +import static com.facebook.presto.sql.ExpressionUtils.expressionOrNullVariables; import static com.facebook.presto.sql.ExpressionUtils.extractConjuncts; import static com.facebook.presto.sql.ExpressionUtils.filterDeterministicConjuncts; import static com.facebook.presto.sql.planner.EqualityInference.createEqualityInference; @@ -70,12 +71,12 @@ */ public class EffectivePredicateExtractor { - private static final Predicate> SYMBOL_MATCHES_EXPRESSION = - entry -> entry.getValue().equals(entry.getKey().toSymbolReference()); + private static final Predicate> VARIABLE_MATCHES_EXPRESSION = + entry -> entry.getValue().equals(new SymbolReference(entry.getKey().getName())); - private static final Function, Expression> ENTRY_TO_EQUALITY = + private static final Function, Expression> VARIABLE_ENTRY_TO_EQUALITY = entry -> { - SymbolReference reference = entry.getKey().toSymbolReference(); + SymbolReference reference = new SymbolReference(entry.getKey().getName()); Expression expression = entry.getValue(); // TODO: this is not correct with respect to NULLs ('reference IS NULL' would be correct, rather than 'reference = NULL') // TODO: switch this to 'IS NOT DISTINCT FROM' syntax when EqualityInference properly supports it @@ -89,19 +90,21 @@ public EffectivePredicateExtractor(ExpressionDomainTranslator domainTranslator) this.domainTranslator = requireNonNull(domainTranslator, "domainTranslator is null"); } - public Expression extract(PlanNode node) + public Expression extract(PlanNode node, TypeProvider types) { - return node.accept(new Visitor(domainTranslator), null); + return node.accept(new Visitor(domainTranslator, types), null); } private static class Visitor extends InternalPlanVisitor { private final ExpressionDomainTranslator domainTranslator; + private final TypeProvider types; - public Visitor(ExpressionDomainTranslator domainTranslator) + public Visitor(ExpressionDomainTranslator domainTranslator, TypeProvider types) { this.domainTranslator = requireNonNull(domainTranslator, "domainTranslator is null"); + this.types = requireNonNull(types, "types is null"); } @Override @@ -124,7 +127,7 @@ public Expression visitAggregation(AggregationNode node, Void context) Expression underlyingPredicate = node.getSource().accept(this, context); - return pullExpressionThroughSymbols(underlyingPredicate, node.getGroupingKeys()); + return pullExpressionThroughVariables(underlyingPredicate, node.getGroupingKeys()); } @Override @@ -144,11 +147,11 @@ public Expression visitFilter(FilterNode node, Void context) public Expression visitExchange(ExchangeNode node, Void context) { return deriveCommonPredicates(node, source -> { - Map mappings = new HashMap<>(); + Map mappings = new HashMap<>(); for (int i = 0; i < node.getInputs().get(source).size(); i++) { mappings.put( - node.getOutputSymbols().get(i), - node.getInputs().get(source).get(i).toSymbolReference()); + node.getOutputVariables().get(i), + new SymbolReference(node.getInputs().get(source).get(i).getName())); } return mappings.entrySet(); }); @@ -162,16 +165,16 @@ public Expression visitProject(ProjectNode node, Void context) Expression underlyingPredicate = node.getSource().accept(this, context); List projectionEqualities = node.getAssignments().entrySet().stream() - .filter(SYMBOL_MATCHES_EXPRESSION.negate()) - .map(ENTRY_TO_EQUALITY) + .filter(VARIABLE_MATCHES_EXPRESSION.negate()) + .map(VARIABLE_ENTRY_TO_EQUALITY) .collect(toImmutableList()); - return pullExpressionThroughSymbols(combineConjuncts( + return pullExpressionThroughVariables(combineConjuncts( ImmutableList.builder() .addAll(projectionEqualities) .add(underlyingPredicate) .build()), - node.getOutputSymbols()); + node.getOutputVariables()); } @Override @@ -201,8 +204,8 @@ public Expression visitDistinctLimit(DistinctLimitNode node, Void context) @Override public Expression visitTableScan(TableScanNode node, Void context) { - Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); - return domainTranslator.toPredicate(node.getCurrentConstraint().simplify().transform(assignments::get)); + Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); + return domainTranslator.toPredicate(node.getCurrentConstraint().simplify().transform(column -> assignments.containsKey(column) ? new Symbol(assignments.get(column).getName()) : null)); } @Override @@ -220,7 +223,7 @@ public Expression visitWindow(WindowNode node, Void context) @Override public Expression visitUnion(UnionNode node, Void context) { - return deriveCommonPredicates(node, source -> Multimaps.transformValues(node.outputSymbolMap(source), Symbol::toSymbolReference).entries()); + return deriveCommonPredicates(node, source -> Multimaps.transformValues(node.outputMap(source), variable -> new SymbolReference(variable.getName())).entries()); } @Override @@ -235,42 +238,42 @@ public Expression visitJoin(JoinNode node, Void context) switch (node.getType()) { case INNER: - return pullExpressionThroughSymbols(combineConjuncts(ImmutableList.builder() + return pullExpressionThroughVariables(combineConjuncts(ImmutableList.builder() .add(leftPredicate) .add(rightPredicate) .add(combineConjuncts(joinConjuncts)) .add(node.getFilter().map(OriginalExpressionUtils::castToExpression).orElse(TRUE_LITERAL)) - .build()), node.getOutputSymbols()); + .build()), node.getOutputVariables()); case LEFT: return combineConjuncts(ImmutableList.builder() - .add(pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols())) - .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) - .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) + .add(pullExpressionThroughVariables(leftPredicate, node.getOutputVariables())) + .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputVariables(), node.getRight().getOutputVariables()::contains)) + .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputVariables(), node.getRight().getOutputVariables()::contains)) .build()); case RIGHT: return combineConjuncts(ImmutableList.builder() - .add(pullExpressionThroughSymbols(rightPredicate, node.getOutputSymbols())) - .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains)) - .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains)) + .add(pullExpressionThroughVariables(rightPredicate, node.getOutputVariables())) + .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputVariables(), node.getLeft().getOutputVariables()::contains)) + .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputVariables(), node.getLeft().getOutputVariables()::contains)) .build()); case FULL: return combineConjuncts(ImmutableList.builder() - .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains)) - .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) - .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains, node.getRight().getOutputSymbols()::contains)) + .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputVariables(), node.getLeft().getOutputVariables()::contains)) + .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputVariables(), node.getRight().getOutputVariables()::contains)) + .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputVariables(), node.getLeft().getOutputVariables()::contains, node.getRight().getOutputVariables()::contains)) .build()); default: throw new UnsupportedOperationException("Unknown join type: " + node.getType()); } } - private static Iterable pullNullableConjunctsThroughOuterJoin(List conjuncts, Collection outputSymbols, Predicate... nullSymbolScopes) + private Iterable pullNullableConjunctsThroughOuterJoin(List conjuncts, Collection outputVariables, Predicate... nullVariableScopes) { // Conjuncts without any symbol dependencies cannot be applied to the effective predicate (e.g. FALSE literal) return conjuncts.stream() - .map(expression -> pullExpressionThroughSymbols(expression, outputSymbols)) + .map(expression -> pullExpressionThroughVariables(expression, outputVariables)) .map(expression -> SymbolsExtractor.extractAll(expression).isEmpty() ? TRUE_LITERAL : expression) - .map(expressionOrNullSymbols(nullSymbolScopes)) + .map(expressionOrNullVariables(types, nullVariableScopes)) .collect(toImmutableList()); } @@ -290,20 +293,20 @@ public Expression visitSpatialJoin(SpatialJoinNode node, Void context) switch (node.getType()) { case INNER: return combineConjuncts(ImmutableList.builder() - .add(pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols())) - .add(pullExpressionThroughSymbols(rightPredicate, node.getOutputSymbols())) + .add(pullExpressionThroughVariables(leftPredicate, node.getOutputVariables())) + .add(pullExpressionThroughVariables(rightPredicate, node.getOutputVariables())) .build()); case LEFT: return combineConjuncts(ImmutableList.builder() - .add(pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols())) - .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) + .add(pullExpressionThroughVariables(leftPredicate, node.getOutputVariables())) + .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputVariables(), node.getRight().getOutputVariables()::contains)) .build()); default: throw new IllegalArgumentException("Unsupported spatial join type: " + node.getType()); } } - private Expression deriveCommonPredicates(PlanNode node, Function>> mapping) + private Expression deriveCommonPredicates(PlanNode node, Function>> mapping) { // Find the predicates that can be pulled up from each source List> sourceOutputConjuncts = new ArrayList<>(); @@ -311,16 +314,16 @@ private Expression deriveCommonPredicates(PlanNode node, Function equalities = mapping.apply(i).stream() - .filter(SYMBOL_MATCHES_EXPRESSION.negate()) - .map(ENTRY_TO_EQUALITY) + .filter(VARIABLE_MATCHES_EXPRESSION.negate()) + .map(VARIABLE_ENTRY_TO_EQUALITY) .collect(toImmutableList()); - sourceOutputConjuncts.add(ImmutableSet.copyOf(extractConjuncts(pullExpressionThroughSymbols(combineConjuncts( + sourceOutputConjuncts.add(ImmutableSet.copyOf(extractConjuncts(pullExpressionThroughVariables(combineConjuncts( ImmutableList.builder() .addAll(equalities) .add(underlyingPredicate) .build()), - node.getOutputSymbols())))); + node.getOutputVariables())))); } // Find the intersection of predicates across all sources @@ -334,28 +337,21 @@ private Expression deriveCommonPredicates(PlanNode node, Function pullExpressionsThroughSymbols(List expressions, Collection symbols) - { - return expressions.stream() - .map(expression -> pullExpressionThroughSymbols(expression, symbols)) - .collect(toImmutableList()); - } - - private static Expression pullExpressionThroughSymbols(Expression expression, Collection symbols) + private Expression pullExpressionThroughVariables(Expression expression, Collection variables) { EqualityInference equalityInference = createEqualityInference(expression); ImmutableList.Builder effectiveConjuncts = ImmutableList.builder(); for (Expression conjunct : EqualityInference.nonInferrableConjuncts(expression)) { if (ExpressionDeterminismEvaluator.isDeterministic(conjunct)) { - Expression rewritten = equalityInference.rewriteExpression(conjunct, in(symbols)); + Expression rewritten = equalityInference.rewriteExpression(conjunct, in(variables), types); if (rewritten != null) { effectiveConjuncts.add(rewritten); } } } - effectiveConjuncts.addAll(equalityInference.generateEqualitiesPartitionedBy(in(symbols)).getScopeEqualities()); + effectiveConjuncts.addAll(equalityInference.generateEqualitiesPartitionedBy(in(variables), types).getScopeEqualities()); return combineConjuncts(effectiveConjuncts.build()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/EqualityInference.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/EqualityInference.java index 8199b7de7aedc..b5651cb3fdd2a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/EqualityInference.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/EqualityInference.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; @@ -96,10 +97,10 @@ private EqualityInference(Iterable> equalityGroups, Set symbolScope) + public Expression rewriteExpression(Expression expression, Predicate variableScope, TypeProvider types) { checkArgument(isDeterministic(expression), "Only deterministic expressions may be considered for rewrite"); - return rewriteExpression(expression, symbolScope, true); + return rewriteExpression(expression, variableScope, true, types); } /** @@ -107,12 +108,12 @@ public Expression rewriteExpression(Expression expression, Predicate sym * given the known equalities. Returns null if unsuccessful. * This method allows rewriting non-deterministic expressions. */ - public Expression rewriteExpressionAllowNonDeterministic(Expression expression, Predicate symbolScope) + public Expression rewriteExpressionAllowNonDeterministic(Expression expression, Predicate variableScope, TypeProvider types) { - return rewriteExpression(expression, symbolScope, true); + return rewriteExpression(expression, variableScope, true, types); } - private Expression rewriteExpression(Expression expression, Predicate symbolScope, boolean allowFullReplacement) + private Expression rewriteExpression(Expression expression, Predicate variableScope, boolean allowFullReplacement, TypeProvider types) { Iterable subExpressions = SubExpressionExtractor.extract(expression); if (!allowFullReplacement) { @@ -121,7 +122,7 @@ private Expression rewriteExpression(Expression expression, Predicate sy ImmutableMap.Builder expressionRemap = ImmutableMap.builder(); for (Expression subExpression : subExpressions) { - Expression canonical = getScopedCanonical(subExpression, symbolScope); + Expression canonical = getScopedCanonical(subExpression, variableScope, types); if (canonical != null) { expressionRemap.put(subExpression, canonical); } @@ -131,7 +132,7 @@ private Expression rewriteExpression(Expression expression, Predicate sy // larger subtrees over smaller subtrees // TODO: this rewrite can probably be made more sophisticated Expression rewritten = ExpressionTreeRewriter.rewriteWith(new ExpressionNodeInliner(expressionRemap.build()), expression); - if (!symbolToExpressionPredicate(symbolScope).apply(rewritten)) { + if (!variableToExpressionPredicate(variableScope, types).apply(rewritten)) { // If the rewritten is still not compliant with the symbol scope, just give up return null; } @@ -139,7 +140,7 @@ private Expression rewriteExpression(Expression expression, Predicate sy } /** - * Dumps the inference equalities as equality expressions that are partitioned by the symbolScope. + * Dumps the inference equalities as equality expressions that are partitioned by the variableScope. * All stored equalities are returned in a compact set and will be classified into three groups as determined by the symbol scope: *
    *
  1. equalities that fit entirely within the symbol scope
  2. @@ -166,7 +167,7 @@ private Expression rewriteExpression(Expression expression, Predicate sy * d = f * */ - public EqualityPartition generateEqualitiesPartitionedBy(Predicate symbolScope) + public EqualityPartition generateEqualitiesPartitionedBy(Predicate variableScope, TypeProvider types) { ImmutableSet.Builder scopeEqualities = ImmutableSet.builder(); ImmutableSet.Builder scopeComplementEqualities = ImmutableSet.builder(); @@ -179,11 +180,11 @@ public EqualityPartition generateEqualitiesPartitionedBy(Predicate symbo // Try to push each non-derived expression into one side of the scope for (Expression expression : filter(equalitySet, not(derivedExpressions::contains))) { - Expression scopeRewritten = rewriteExpression(expression, symbolScope, false); + Expression scopeRewritten = rewriteExpression(expression, variableScope, false, types); if (scopeRewritten != null) { scopeExpressions.add(scopeRewritten); } - Expression scopeComplementRewritten = rewriteExpression(expression, not(symbolScope), false); + Expression scopeComplementRewritten = rewriteExpression(expression, not(variableScope), false, types); if (scopeComplementRewritten != null) { scopeComplementExpressions.add(scopeComplementRewritten); } @@ -234,22 +235,22 @@ private static Expression getCanonical(Iterable expressions) } /** - * Returns a canonical expression that is fully contained by the symbolScope and that is equivalent + * Returns a canonical expression that is fully contained by the variableScope and that is equivalent * to the specified expression. Returns null if unable to to find a canonical. */ @VisibleForTesting - Expression getScopedCanonical(Expression expression, Predicate symbolScope) + Expression getScopedCanonical(Expression expression, Predicate variableScope, TypeProvider types) { Expression canonicalIndex = canonicalMap.get(expression); if (canonicalIndex == null) { return null; } - return getCanonical(filter(equalitySets.get(canonicalIndex), symbolToExpressionPredicate(symbolScope))); + return getCanonical(filter(equalitySets.get(canonicalIndex), variableToExpressionPredicate(variableScope, types))); } - private static Predicate symbolToExpressionPredicate(final Predicate symbolScope) + private static Predicate variableToExpressionPredicate(final Predicate variableScope, TypeProvider types) { - return expression -> Iterables.all(SymbolsExtractor.extractUnique(expression), symbolScope); + return expression -> Iterables.all(SymbolsExtractor.extractUniqueVariable(expression, types), variableScope); } /** diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java index 3af1384d4a104..9608eb396dbc6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java @@ -100,9 +100,6 @@ public Void visitAggregation(AggregationNode node, ImmutableList.Builder mapping, Expression expression, TypeProvider types) + { + return inlineVariables(mapping::get, expression, types); + } + + public static Expression inlineVariables(Function mapping, Expression expression, TypeProvider types) + { + return new ExpressionVariableInliner(mapping, types).rewrite(expression); + } + + private final Function mapping; + private final TypeProvider types; + + private ExpressionVariableInliner(Function mapping, TypeProvider types) + { + this.mapping = mapping; + this.types = types; + } + + private Expression rewrite(Expression expression) + { + return ExpressionTreeRewriter.rewriteWith(new ExpressionVariableInliner.Visitor(), expression); + } + + private class Visitor + extends ExpressionRewriter + { + private final Set excludedNames = new HashSet<>(); + + @Override + public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter treeRewriter) + { + if (excludedNames.contains(node.getName())) { + return node; + } + + Expression expression = mapping.apply(new VariableReferenceExpression(node.getName(), types.get(new Symbol(node.getName())))); + checkState(expression != null, "Cannot resolve symbol %s", node.getName()); + return expression; + } + + @Override + public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter treeRewriter) + { + for (LambdaArgumentDeclaration argument : node.getArguments()) { + String argumentName = argument.getName().getValue(); + // Variable names are unique. As a result, a variable should never be excluded multiple times. + checkArgument(!excludedNames.contains(argumentName)); + excludedNames.add(argumentName); + } + Expression result = treeRewriter.defaultRewrite(node, context); + for (LambdaArgumentDeclaration argument : node.getArguments()) { + excludedNames.remove(argument.getName().getValue()); + } + return result; + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/GroupingOperationRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/GroupingOperationRewriter.java index 8ad4290fc2ac2..ae049ef49111c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/GroupingOperationRewriter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/GroupingOperationRewriter.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.FieldId; import com.facebook.presto.sql.analyzer.RelationId; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; @@ -23,6 +24,7 @@ import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.SubscriptExpression; +import com.facebook.presto.sql.tree.SymbolReference; import java.util.List; import java.util.Map; @@ -38,9 +40,9 @@ public final class GroupingOperationRewriter { private GroupingOperationRewriter() {} - public static Expression rewriteGroupingOperation(GroupingOperation expression, List> groupingSets, Map, FieldId> columnReferenceFields, Optional groupIdSymbol) + public static Expression rewriteGroupingOperation(GroupingOperation expression, List> groupingSets, Map, FieldId> columnReferenceFields, Optional groupIdVariable) { - requireNonNull(groupIdSymbol, "groupIdSymbol is null"); + requireNonNull(groupIdVariable, "groupIdVariable is null"); // No GroupIdNode and a GROUPING() operation imply a single grouping, which // means that any columns specified as arguments to GROUPING() will be included @@ -51,7 +53,7 @@ public static Expression rewriteGroupingOperation(GroupingOperation expression, return new LongLiteral("0"); } else { - checkState(groupIdSymbol.isPresent(), "groupId symbol is missing"); + checkState(groupIdVariable.isPresent(), "groupId symbol is missing"); RelationId relationId = columnReferenceFields.get(NodeRef.of(expression.getGroupingColumns().get(0))).getRelationId(); @@ -70,7 +72,7 @@ public static Expression rewriteGroupingOperation(GroupingOperation expression, // It is necessary to add a 1 to the groupId because the underlying array is indexed starting at 1 return new SubscriptExpression( new ArrayConstructor(groupingResults), - new ArithmeticBinaryExpression(ADD, groupIdSymbol.get().toSymbolReference(), new GenericLiteral("BIGINT", "1"))); + new ArithmeticBinaryExpression(ADD, new SymbolReference(groupIdVariable.get().getName()), new GenericLiteral("BIGINT", "1"))); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index fd58856bdb8af..5deda36a31069 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -111,8 +111,8 @@ import com.facebook.presto.spi.function.FunctionMetadata; import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.spi.predicate.NullableValue; import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.InputReferenceExpression; import com.facebook.presto.spi.relation.LambdaDefinitionExpression; import com.facebook.presto.spi.relation.RowExpression; @@ -175,7 +175,7 @@ import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.planner.plan.WindowNode.Frame; import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator; -import com.facebook.presto.sql.relational.SymbolToChannelTranslator; +import com.facebook.presto.sql.relational.VariableToChannelTranslator; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; import com.facebook.presto.sql.tree.LambdaExpression; @@ -268,7 +268,6 @@ import static com.facebook.presto.util.SpatialJoinUtils.ST_WITHIN; import static com.facebook.presto.util.SpatialJoinUtils.extractSupportedSpatialComparisons; import static com.facebook.presto.util.SpatialJoinUtils.extractSupportedSpatialFunctions; -import static com.google.common.base.Functions.forMap; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; @@ -279,7 +278,6 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Range.closedOpen; import static io.airlift.units.DataSize.Unit.BYTE; -import static java.lang.String.format; import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.stream.IntStream.range; @@ -371,7 +369,7 @@ public LocalExecutionPlan plan( OutputBuffer outputBuffer, TaskExchangeClientManager taskExchangeClientManager) { - List outputLayout = partitioningScheme.getOutputLayout(); + List outputLayout = partitioningScheme.getOutputLayout(); if (partitioningScheme.getPartitioning().getHandle().equals(FIXED_BROADCAST_DISTRIBUTION) || partitioningScheme.getPartitioning().getHandle().equals(FIXED_ARBITRARY_DISTRIBUTION) || @@ -391,7 +389,7 @@ public LocalExecutionPlan plan( // We can convert the symbols directly into channels, because the root must be a sink and therefore the layout is fixed List partitionChannels; - List> partitionConstants; + List> partitionConstants; List partitionChannelTypes; if (partitioningScheme.getHashColumn().isPresent()) { partitionChannels = ImmutableList.of(outputLayout.indexOf(partitioningScheme.getHashColumn().get())); @@ -404,7 +402,7 @@ public LocalExecutionPlan plan( if (argument.isConstant()) { return -1; } - return outputLayout.indexOf(argument.getColumn()); + return outputLayout.indexOf(argument.getVariableReference()); }) .collect(toImmutableList()); partitionConstants = partitioningScheme.getPartitioning().getArguments().stream() @@ -412,7 +410,7 @@ public LocalExecutionPlan plan( if (argument.isConstant()) { return Optional.of(argument.getConstant()); } - return Optional.empty(); + return Optional.empty(); }) .collect(toImmutableList()); partitionChannelTypes = partitioningScheme.getPartitioning().getArguments().stream() @@ -427,7 +425,7 @@ public LocalExecutionPlan plan( PartitionFunction partitionFunction = nodePartitioningManager.getPartitionFunction(taskContext.getSession(), partitioningScheme, partitionChannelTypes); OptionalInt nullChannel = OptionalInt.empty(); - Set partitioningColumns = partitioningScheme.getPartitioning().getColumns(); + Set partitioningColumns = partitioningScheme.getPartitioning().getVariableReferences(); // partitioningColumns expected to have one column in the normal case, and zero columns when partitioning on a constant checkArgument(!partitioningScheme.isReplicateNullsAndAny() || partitioningColumns.size() <= 1); @@ -457,7 +455,7 @@ public LocalExecutionPlan plan( TaskContext taskContext, StageExecutionDescriptor stageExecutionDescriptor, PlanNode plan, - List outputLayout, + List outputLayout, TypeProvider types, List partitionedSourceOrder, OutputFactory outputOperatorFactory, @@ -471,7 +469,7 @@ public LocalExecutionPlan plan( Function pagePreprocessor = enforceLayoutProcessor(outputLayout, physicalOperation.getLayout()); List outputTypes = outputLayout.stream() - .map(types::get) + .map(VariableReferenceExpression::getType) .collect(toImmutableList()); context.addDriverFactory( @@ -660,14 +658,14 @@ public void setDriverInstanceCount(int driverInstanceCount) private static class IndexSourceContext { - private final SetMultimap indexLookupToProbeInput; + private final SetMultimap indexLookupToProbeInput; - public IndexSourceContext(SetMultimap indexLookupToProbeInput) + public IndexSourceContext(SetMultimap indexLookupToProbeInput) { this.indexLookupToProbeInput = ImmutableSetMultimap.copyOf(requireNonNull(indexLookupToProbeInput, "indexLookupToProbeInput is null")); } - private SetMultimap getIndexLookupToProbeInput() + private SetMultimap getIndexLookupToProbeInput() { return indexLookupToProbeInput; } @@ -732,8 +730,8 @@ private PhysicalOperation createMergeSource(RemoteSourceNode node, LocalExecutio context.setDriverInstanceCount(1); OrderingScheme orderingScheme = node.getOrderingScheme().get(); - ImmutableMap layout = makeLayout(node); - List sortChannels = getChannelsForSymbols(orderingScheme.getOrderBy(), layout); + ImmutableMap layout = makeLayout(node); + List sortChannels = getChannelsForVariables(orderingScheme.getOrderBy(), layout); List sortOrder = orderingScheme.getOrderingList(); List types = getSourceOperatorTypes(node, context.getTypes()); @@ -796,8 +794,7 @@ public PhysicalOperation visitRowNumber(RowNumberNode node, LocalExecutionPlanCo { PhysicalOperation source = node.getSource().accept(this, context); - List partitionBySymbols = node.getPartitionBy(); - List partitionChannels = getChannelsForSymbols(partitionBySymbols, source.getLayout()); + List partitionChannels = getChannelsForVariables(node.getPartitionBy(), source.getLayout()); List partitionTypes = partitionChannels.stream() .map(channel -> source.getTypes().get(channel)) @@ -809,14 +806,14 @@ public PhysicalOperation visitRowNumber(RowNumberNode node, LocalExecutionPlanCo } // compute the layout of the output from the window operator - ImmutableMap.Builder outputMappings = ImmutableMap.builder(); + ImmutableMap.Builder outputMappings = ImmutableMap.builder(); outputMappings.putAll(source.getLayout()); // row number function goes in the last channel int channel = source.getTypes().size(); - outputMappings.put(node.getRowNumberSymbol(), channel); + outputMappings.put(node.getRowNumberVariable(), channel); - Optional hashChannel = node.getHashSymbol().map(channelGetter(source)); + Optional hashChannel = node.getHashVariable().map(variableChannelGetter(source)); OperatorFactory operatorFactory = new RowNumberOperator.RowNumberOperatorFactory( context.getNextOperatorId(), node.getId(), @@ -836,15 +833,14 @@ public PhysicalOperation visitTopNRowNumber(TopNRowNumberNode node, LocalExecuti { PhysicalOperation source = node.getSource().accept(this, context); - List partitionBySymbols = node.getPartitionBy(); - List partitionChannels = getChannelsForSymbols(partitionBySymbols, source.getLayout()); + List partitionChannels = getChannelsForVariables(node.getPartitionBy(), source.getLayout()); List partitionTypes = partitionChannels.stream() .map(channel -> source.getTypes().get(channel)) .collect(toImmutableList()); - List orderBySymbols = node.getOrderingScheme().getOrderBy(); - List sortChannels = getChannelsForSymbols(orderBySymbols, source.getLayout()); - List sortOrder = orderBySymbols.stream() + List orderByVariables = node.getOrderingScheme().getOrderBy(); + List sortChannels = getChannelsForVariables(orderByVariables, source.getLayout()); + List sortOrder = orderByVariables.stream() .map(symbol -> node.getOrderingScheme().getOrdering(symbol)) .collect(toImmutableList()); @@ -854,16 +850,16 @@ public PhysicalOperation visitTopNRowNumber(TopNRowNumberNode node, LocalExecuti } // compute the layout of the output from the window operator - ImmutableMap.Builder outputMappings = ImmutableMap.builder(); + ImmutableMap.Builder outputMappings = ImmutableMap.builder(); outputMappings.putAll(source.getLayout()); if (!node.isPartial() || !partitionChannels.isEmpty()) { // row number function goes in the last channel int channel = source.getTypes().size(); - outputMappings.put(node.getRowNumberSymbol(), channel); + outputMappings.put(node.getRowNumberVariable(), channel); } - Optional hashChannel = node.getHashSymbol().map(channelGetter(source)); + Optional hashChannel = node.getHashVariable().map(variableChannelGetter(source)); OperatorFactory operatorFactory = new TopNRowNumberOperator.TopNRowNumberOperatorFactory( context.getNextOperatorId(), node.getId(), @@ -887,16 +883,15 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext { PhysicalOperation source = node.getSource().accept(this, context); - List partitionBySymbols = node.getPartitionBy(); - List partitionChannels = ImmutableList.copyOf(getChannelsForSymbols(partitionBySymbols, source.getLayout())); - List preGroupedChannels = ImmutableList.copyOf(getChannelsForSymbols(ImmutableList.copyOf(node.getPrePartitionedInputs()), source.getLayout())); + List partitionChannels = ImmutableList.copyOf(getChannelsForVariables(node.getPartitionBy(), source.getLayout())); + List preGroupedChannels = ImmutableList.copyOf(getChannelsForVariables(node.getPrePartitionedInputs(), source.getLayout())); List sortChannels = ImmutableList.of(); List sortOrder = ImmutableList.of(); if (node.getOrderingScheme().isPresent()) { OrderingScheme orderingScheme = node.getOrderingScheme().get(); - sortChannels = getChannelsForSymbols(orderingScheme.getOrderBy(), source.getLayout()); + sortChannels = getChannelsForVariables(orderingScheme.getOrderBy(), source.getLayout()); sortOrder = orderingScheme.getOrderingList(); } @@ -906,8 +901,8 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext } ImmutableList.Builder windowFunctionsBuilder = ImmutableList.builder(); - ImmutableList.Builder windowFunctionOutputSymbolsBuilder = ImmutableList.builder(); - for (Map.Entry entry : node.getWindowFunctions().entrySet()) { + ImmutableList.Builder windowFunctionOutputVariablesBuilder = ImmutableList.builder(); + for (Map.Entry entry : node.getWindowFunctions().entrySet()) { Optional frameStartChannel = Optional.empty(); Optional frameEndChannel = Optional.empty(); @@ -926,29 +921,28 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext ImmutableList.Builder arguments = ImmutableList.builder(); for (RowExpression argument : call.getArguments()) { checkState(argument instanceof VariableReferenceExpression); - Symbol argumentSymbol = new Symbol(((VariableReferenceExpression) argument).getName()); - arguments.add(source.getLayout().get(argumentSymbol)); + arguments.add(source.getLayout().get(argument)); } - Symbol symbol = entry.getKey(); + VariableReferenceExpression variable = entry.getKey(); FunctionManager functionManager = metadata.getFunctionManager(); WindowFunctionSupplier windowFunctionSupplier = functionManager.getWindowFunctionImplementation(functionHandle); Type type = metadata.getType(functionManager.getFunctionMetadata(functionHandle).getReturnType()); windowFunctionsBuilder.add(window(windowFunctionSupplier, type, frameInfo, arguments.build())); - windowFunctionOutputSymbolsBuilder.add(symbol); + windowFunctionOutputVariablesBuilder.add(variable); } - List windowFunctionOutputSymbols = windowFunctionOutputSymbolsBuilder.build(); + List windowFunctionOutputVariables = windowFunctionOutputVariablesBuilder.build(); // compute the layout of the output from the window operator - ImmutableMap.Builder outputMappings = ImmutableMap.builder(); - for (Symbol symbol : node.getSource().getOutputSymbols()) { - outputMappings.put(symbol, source.getLayout().get(symbol)); + ImmutableMap.Builder outputMappings = ImmutableMap.builder(); + for (VariableReferenceExpression variable : node.getSource().getOutputVariables()) { + outputMappings.put(variable, source.getLayout().get(variable)); } // window functions go in remaining channels starting after the last channel from the source operator, one per channel int channel = source.getTypes().size(); - for (Symbol symbol : windowFunctionOutputSymbols) { - outputMappings.put(symbol, channel); + for (VariableReferenceExpression variable : windowFunctionOutputVariables) { + outputMappings.put(variable, channel); channel++; } @@ -974,13 +968,13 @@ public PhysicalOperation visitTopN(TopNNode node, LocalExecutionPlanContext cont { PhysicalOperation source = node.getSource().accept(this, context); - List orderBySymbols = node.getOrderingScheme().getOrderBy(); + List orderByVariables = node.getOrderingScheme().getOrderBy(); List sortChannels = new ArrayList<>(); List sortOrders = new ArrayList<>(); - for (Symbol symbol : orderBySymbols) { - sortChannels.add(source.getLayout().get(symbol)); - sortOrders.add(node.getOrderingScheme().getOrdering(symbol)); + for (VariableReferenceExpression variable : orderByVariables) { + sortChannels.add(source.getLayout().get(variable)); + sortOrders.add(node.getOrderingScheme().getOrdering(variable)); } OperatorFactory operator = new TopNOperatorFactory( @@ -999,13 +993,13 @@ public PhysicalOperation visitSort(SortNode node, LocalExecutionPlanContext cont { PhysicalOperation source = node.getSource().accept(this, context); - List orderBySymbols = node.getOrderingScheme().getOrderBy(); + List orderByVariables = node.getOrderingScheme().getOrderBy(); - List orderByChannels = getChannelsForSymbols(orderBySymbols, source.getLayout()); + List orderByChannels = getChannelsForVariables(orderByVariables, source.getLayout()); ImmutableList.Builder sortOrder = ImmutableList.builder(); - for (Symbol symbol : orderBySymbols) { - sortOrder.add(node.getOrderingScheme().getOrdering(symbol)); + for (VariableReferenceExpression variable : orderByVariables) { + sortOrder.add(node.getOrderingScheme().getOrdering(variable)); } ImmutableList.Builder outputChannels = ImmutableList.builder(); @@ -1040,8 +1034,8 @@ public PhysicalOperation visitDistinctLimit(DistinctLimitNode node, LocalExecuti { PhysicalOperation source = node.getSource().accept(this, context); - Optional hashChannel = node.getHashSymbol().map(channelGetter(source)); - List distinctChannels = getChannelsForSymbols(node.getDistinctSymbols(), source.getLayout()); + Optional hashChannel = node.getHashVariable().map(variableChannelGetter(source)); + List distinctChannels = getChannelsForVariables(node.getDistinctVariables(), source.getLayout()); OperatorFactory operatorFactory = new DistinctLimitOperatorFactory( context.getNextOperatorId(), @@ -1058,18 +1052,18 @@ public PhysicalOperation visitDistinctLimit(DistinctLimitNode node, LocalExecuti public PhysicalOperation visitGroupId(GroupIdNode node, LocalExecutionPlanContext context) { PhysicalOperation source = node.getSource().accept(this, context); - Map newLayout = new HashMap<>(); + Map newLayout = new HashMap<>(); ImmutableList.Builder outputTypes = ImmutableList.builder(); int outputChannel = 0; - for (Symbol output : node.getGroupingSets().stream().flatMap(Collection::stream).collect(Collectors.toSet())) { + for (VariableReferenceExpression output : node.getGroupingSets().stream().flatMap(Collection::stream).collect(Collectors.toSet())) { newLayout.put(output, outputChannel++); outputTypes.add(source.getTypes().get(source.getLayout().get(node.getGroupingColumns().get(output)))); } - Map argumentMappings = new HashMap<>(); - for (Symbol output : node.getAggregationArguments()) { + Map argumentMappings = new HashMap<>(); + for (VariableReferenceExpression output : node.getAggregationArguments()) { int inputChannel = source.getLayout().get(output); newLayout.put(output, outputChannel++); @@ -1079,21 +1073,21 @@ public PhysicalOperation visitGroupId(GroupIdNode node, LocalExecutionPlanContex // for every grouping set, create a mapping of all output to input channels (including arguments) ImmutableList.Builder> mappings = ImmutableList.builder(); - for (List groupingSet : node.getGroupingSets()) { + for (List groupingSet : node.getGroupingSets()) { ImmutableMap.Builder setMapping = ImmutableMap.builder(); - for (Symbol output : groupingSet) { + for (VariableReferenceExpression output : groupingSet) { setMapping.put(newLayout.get(output), source.getLayout().get(node.getGroupingColumns().get(output))); } - for (Symbol output : argumentMappings.keySet()) { + for (VariableReferenceExpression output : argumentMappings.keySet()) { setMapping.put(newLayout.get(output), argumentMappings.get(output)); } mappings.add(setMapping.build()); } - newLayout.put(node.getGroupIdSymbol(), outputChannel); + newLayout.put(node.getGroupIdVariable(), outputChannel); outputTypes.add(BIGINT); OperatorFactory groupIdOperatorFactory = new GroupIdOperator.GroupIdOperatorFactory(context.getNextOperatorId(), @@ -1124,8 +1118,8 @@ public PhysicalOperation visitMarkDistinct(MarkDistinctNode node, LocalExecution { PhysicalOperation source = node.getSource().accept(this, context); - List channels = getChannelsForSymbols(node.getDistinctSymbols(), source.getLayout()); - Optional hashChannel = node.getHashSymbol().map(channelGetter(source)); + List channels = getChannelsForVariables(node.getDistinctVariables(), source.getLayout()); + Optional hashChannel = node.getHashVariable().map(variableChannelGetter(source)); MarkDistinctOperatorFactory operator = new MarkDistinctOperatorFactory(context.getNextOperatorId(), node.getId(), source.getTypes(), channels, hashChannel, joinCompiler); return new PhysicalOperation(operator, makeLayout(node), context, source); } @@ -1147,9 +1141,9 @@ public PhysicalOperation visitFilter(FilterNode node, LocalExecutionPlanContext PlanNode sourceNode = node.getSource(); RowExpression filterExpression = node.getPredicate(); - List outputSymbols = node.getOutputSymbols(); + List outputVariables = node.getOutputVariables(); - return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), Assignments.identity(outputSymbols), outputSymbols); + return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), Assignments.identity(outputVariables), outputVariables); } @Override @@ -1166,9 +1160,7 @@ public PhysicalOperation visitProject(ProjectNode node, LocalExecutionPlanContex sourceNode = node.getSource(); } - List outputSymbols = node.getOutputSymbols(); - - return visitScanFilterAndProject(context, node.getId(), sourceNode, filterExpression, node.getAssignments(), outputSymbols); + return visitScanFilterAndProject(context, node.getId(), sourceNode, filterExpression, node.getAssignments(), node.getOutputVariables()); } // TODO: This should be refactored, so that there's an optimizer that merges scan-filter-project into a single PlanNode @@ -1178,11 +1170,11 @@ private PhysicalOperation visitScanFilterAndProject( PlanNode sourceNode, Optional filterExpression, Assignments assignments, - List outputSymbols) + List outputVariables) { // if source is a table scan we fold it directly into the filter and project // otherwise we plan it as a normal operator - Map sourceLayout; + Map sourceLayout; List columns = null; PhysicalOperation source = null; if (sourceNode instanceof TableScanNode) { @@ -1192,13 +1184,11 @@ private PhysicalOperation visitScanFilterAndProject( sourceLayout = new LinkedHashMap<>(); columns = new ArrayList<>(); int channel = 0; - for (Symbol symbol : tableScanNode.getOutputSymbols()) { - columns.add(tableScanNode.getAssignments().get(symbol)); + for (VariableReferenceExpression variable : tableScanNode.getOutputVariables()) { + columns.add(tableScanNode.getAssignments().get(variable)); Integer input = channel; - sourceLayout.put(symbol, input); - - Type type = requireNonNull(context.getTypes().get(symbol), format("No type for symbol %s", symbol)); + sourceLayout.put(variable, input); channel++; } @@ -1216,18 +1206,18 @@ private PhysicalOperation visitScanFilterAndProject( } // build output mapping - ImmutableMap.Builder outputMappingsBuilder = ImmutableMap.builder(); - for (int i = 0; i < outputSymbols.size(); i++) { - Symbol symbol = outputSymbols.get(i); - outputMappingsBuilder.put(symbol, i); + ImmutableMap.Builder outputMappingsBuilder = ImmutableMap.builder(); + for (int i = 0; i < outputVariables.size(); i++) { + VariableReferenceExpression variable = outputVariables.get(i); + outputMappingsBuilder.put(variable, i); } - Map outputMappings = outputMappingsBuilder.build(); + Map outputMappings = outputMappingsBuilder.build(); // compiler uses inputs instead of symbols, so rewrite the expressions first List projections = new ArrayList<>(); - for (Symbol symbol : outputSymbols) { - projections.add(assignments.get(symbol)); + for (VariableReferenceExpression variable : outputVariables) { + projections.add(assignments.get(variable)); } Map, Type> expressionTypes = getExpressionTypes( @@ -1283,19 +1273,19 @@ private PhysicalOperation visitScanFilterAndProject( } // TODO: migrate `toRowExpression` to `bindChannels` - private RowExpression toRowExpression(Expression expression, Map, Type> types, Map sourceLayout) + private RowExpression toRowExpression(Expression expression, Map, Type> types, Map sourceLayout) { return SqlToRowExpressionTranslator.translate(expression, types, sourceLayout, metadata.getFunctionManager(), metadata.getTypeManager(), session, true); } - private RowExpression bindChannels(RowExpression expression, Map sourceLayout) + private RowExpression bindChannels(RowExpression expression, Map sourceLayout) { Type type = expression.getType(); Object value = new RowExpressionInterpreter(expression, metadata, session.toConnectorSession(), true).optimize(); if (value instanceof RowExpression) { RowExpression optimized = (RowExpression) value; // building channel info - expression = SymbolToChannelTranslator.translate(optimized, sourceLayout); + expression = VariableToChannelTranslator.translate(optimized, sourceLayout); } else { expression = constant(value, type); @@ -1307,8 +1297,8 @@ private RowExpression bindChannels(RowExpression expression, Map columns = new ArrayList<>(); - for (Symbol symbol : node.getOutputSymbols()) { - columns.add(node.getAssignments().get(symbol)); + for (VariableReferenceExpression variable : node.getOutputVariables()) { + columns.add(node.getAssignments().get(variable)); } OperatorFactory operatorFactory = new TableScanOperatorFactory(context.getNextOperatorId(), node.getId(), pageSourceProvider, columns); @@ -1326,7 +1316,7 @@ public PhysicalOperation visitValues(ValuesNode node, LocalExecutionPlanContext return new PhysicalOperation(operatorFactory, makeLayout(node), context, UNGROUPED_EXECUTION); } - List outputTypes = getSymbolTypes(node.getOutputSymbols(), context.getTypes()); + List outputTypes = node.getOutputVariables().stream().map(VariableReferenceExpression::getType).collect(toImmutableList()); PageBuilder pageBuilder = new PageBuilder(node.getRows().size(), outputTypes); for (List row : node.getRows()) { pageBuilder.declarePosition(); @@ -1347,36 +1337,36 @@ public PhysicalOperation visitUnnest(UnnestNode node, LocalExecutionPlanContext PhysicalOperation source = node.getSource().accept(this, context); ImmutableList.Builder replicateTypes = ImmutableList.builder(); - for (Symbol symbol : node.getReplicateSymbols()) { - replicateTypes.add(context.getTypes().get(symbol)); + for (VariableReferenceExpression variable : node.getReplicateVariables()) { + replicateTypes.add(variable.getType()); } - List unnestSymbols = ImmutableList.copyOf(node.getUnnestSymbols().keySet()); + List unnestVariables = ImmutableList.copyOf(node.getUnnestVariables().keySet()); ImmutableList.Builder unnestTypes = ImmutableList.builder(); - for (Symbol symbol : unnestSymbols) { - unnestTypes.add(context.getTypes().get(symbol)); + for (VariableReferenceExpression variable : unnestVariables) { + unnestTypes.add(variable.getType()); } - Optional ordinalitySymbol = node.getOrdinalitySymbol(); - Optional ordinalityType = ordinalitySymbol.map(context.getTypes()::get); - ordinalityType.ifPresent(type -> checkState(type.equals(BIGINT), "Type of ordinalitySymbol must always be BIGINT.")); + Optional ordinalityVariable = node.getOrdinalityVariable(); + Optional ordinalityType = ordinalityVariable.map(VariableReferenceExpression::getType); + ordinalityType.ifPresent(type -> checkState(type.equals(BIGINT), "Type of ordinalityVariable must always be BIGINT.")); - List replicateChannels = getChannelsForSymbols(node.getReplicateSymbols(), source.getLayout()); - List unnestChannels = getChannelsForSymbols(unnestSymbols, source.getLayout()); + List replicateChannels = getChannelsForVariables(node.getReplicateVariables(), source.getLayout()); + List unnestChannels = getChannelsForVariables(unnestVariables, source.getLayout()); // Source channels are always laid out first, followed by the unnested symbols - ImmutableMap.Builder outputMappings = ImmutableMap.builder(); + ImmutableMap.Builder outputMappings = ImmutableMap.builder(); int channel = 0; - for (Symbol symbol : node.getReplicateSymbols()) { - outputMappings.put(symbol, channel); + for (VariableReferenceExpression variable : node.getReplicateVariables()) { + outputMappings.put(variable, channel); channel++; } - for (Symbol symbol : unnestSymbols) { - for (Symbol unnestedSymbol : node.getUnnestSymbols().get(symbol)) { - outputMappings.put(unnestedSymbol, channel); + for (VariableReferenceExpression variable : unnestVariables) { + for (VariableReferenceExpression unnestedVariable : node.getUnnestVariables().get(variable)) { + outputMappings.put(unnestedVariable, channel); channel++; } } - if (ordinalitySymbol.isPresent()) { - outputMappings.put(ordinalitySymbol.get(), channel); + if (ordinalityVariable.isPresent()) { + outputMappings.put(ordinalityVariable.get(), channel); channel++; } OperatorFactory operatorFactory = new UnnestOperatorFactory( @@ -1390,17 +1380,17 @@ public PhysicalOperation visitUnnest(UnnestNode node, LocalExecutionPlanContext return new PhysicalOperation(operatorFactory, outputMappings.build(), context, source); } - private ImmutableMap makeLayout(PlanNode node) + private ImmutableMap makeLayout(PlanNode node) { - return makeLayoutFromOutputSymbols(node.getOutputSymbols()); + return makeLayoutFromOutputVariables(node.getOutputVariables()); } - private ImmutableMap makeLayoutFromOutputSymbols(List outputSymbols) + private ImmutableMap makeLayoutFromOutputVariables(List outputVariables) { - Builder outputMappings = ImmutableMap.builder(); + Builder outputMappings = ImmutableMap.builder(); int channel = 0; - for (Symbol symbol : outputSymbols) { - outputMappings.put(symbol, channel); + for (VariableReferenceExpression variable : outputVariables) { + outputMappings.put(variable, channel); channel++; } return outputMappings.build(); @@ -1412,19 +1402,19 @@ public PhysicalOperation visitIndexSource(IndexSourceNode node, LocalExecutionPl checkState(context.getIndexSourceContext().isPresent(), "Must be in an index source context"); IndexSourceContext indexSourceContext = context.getIndexSourceContext().get(); - SetMultimap indexLookupToProbeInput = indexSourceContext.getIndexLookupToProbeInput(); - checkState(indexLookupToProbeInput.keySet().equals(node.getLookupSymbols())); + SetMultimap indexLookupToProbeInput = indexSourceContext.getIndexLookupToProbeInput(); + checkState(indexLookupToProbeInput.keySet().equals(node.getLookupVariables())); // Finalize the symbol lookup layout for the index source - List lookupSymbolSchema = ImmutableList.copyOf(node.getLookupSymbols()); + List lookupVariableSchema = ImmutableList.copyOf(node.getLookupVariables()); // Identify how to remap the probe key Input to match the source index lookup layout ImmutableList.Builder remappedProbeKeyChannelsBuilder = ImmutableList.builder(); // Identify overlapping fields that can produce the same lookup symbol. // We will filter incoming keys to ensure that overlapping fields will have the same value. ImmutableList.Builder> overlappingFieldSetsBuilder = ImmutableList.builder(); - for (Symbol lookupSymbol : lookupSymbolSchema) { - Set potentialProbeInputs = indexLookupToProbeInput.get(lookupSymbol); + for (VariableReferenceExpression lookupVariable : node.getLookupVariables()) { + Set potentialProbeInputs = indexLookupToProbeInput.get(lookupVariable); checkState(!potentialProbeInputs.isEmpty(), "Must have at least one source from the probe input"); if (potentialProbeInputs.size() > 1) { overlappingFieldSetsBuilder.add(potentialProbeInputs.stream().collect(toImmutableSet())); @@ -1441,8 +1431,12 @@ public PhysicalOperation visitIndexSource(IndexSourceNode node, LocalExecutionPl }; // Declare the input and output schemas for the index and acquire the actual Index - List lookupSchema = Lists.transform(lookupSymbolSchema, forMap(node.getAssignments())); - List outputSchema = Lists.transform(node.getOutputSymbols(), forMap(node.getAssignments())); + List lookupSchema = lookupVariableSchema.stream().map(node.getAssignments()::get).collect(toImmutableList()); + List outputSchema = node.getAssignments().entrySet().stream() + .filter(entry -> node.getOutputVariables().contains(entry.getKey())) + .map(Map.Entry::getValue) + .collect(toImmutableList()); + ConnectorIndex index = indexManager.getIndex(session, node.getIndexHandle(), lookupSchema, outputSchema); OperatorFactory operatorFactory = new IndexSourceOperator.IndexSourceOperatorFactory(context.getNextOperatorId(), node.getId(), index, probeKeyNormalizer); @@ -1453,28 +1447,28 @@ public PhysicalOperation visitIndexSource(IndexSourceNode node, LocalExecutionPl * This method creates a mapping from each index source lookup symbol (directly applied to the index) * to the corresponding probe key Input */ - private SetMultimap mapIndexSourceLookupSymbolToProbeKeyInput(IndexJoinNode node, Map probeKeyLayout) + private SetMultimap mapIndexSourceLookupSymbolToProbeKeyInput(IndexJoinNode node, Map probeKeyLayout) { - Set indexJoinSymbols = node.getCriteria().stream() + Set indexJoinVariables = node.getCriteria().stream() .map(IndexJoinNode.EquiJoinClause::getIndex) .collect(toImmutableSet()); // Trace the index join symbols to the index source lookup symbols // Map: Index join symbol => Index source lookup symbol - Map indexKeyTrace = IndexJoinOptimizer.IndexKeyTracer.trace(node.getIndexSource(), indexJoinSymbols); + Map indexKeyTrace = IndexJoinOptimizer.IndexKeyTracer.trace(node.getIndexSource(), indexJoinVariables); // Map the index join symbols to the probe key Input - Multimap indexToProbeKeyInput = HashMultimap.create(); + Multimap indexToProbeKeyInput = HashMultimap.create(); for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) { indexToProbeKeyInput.put(clause.getIndex(), probeKeyLayout.get(clause.getProbe())); } // Create the mapping from index source look up symbol to probe key Input - ImmutableSetMultimap.Builder builder = ImmutableSetMultimap.builder(); - for (Map.Entry entry : indexKeyTrace.entrySet()) { - Symbol indexJoinSymbol = entry.getKey(); - Symbol indexLookupSymbol = entry.getValue(); - builder.putAll(indexLookupSymbol, indexToProbeKeyInput.get(indexJoinSymbol)); + ImmutableSetMultimap.Builder builder = ImmutableSetMultimap.builder(); + for (Map.Entry entry : indexKeyTrace.entrySet()) { + VariableReferenceExpression indexJoinVariable = entry.getKey(); + VariableReferenceExpression indexLookupVariable = entry.getValue(); + builder.putAll(indexJoinVariable, indexToProbeKeyInput.get(indexLookupVariable)); } return builder.build(); } @@ -1484,35 +1478,35 @@ public PhysicalOperation visitIndexJoin(IndexJoinNode node, LocalExecutionPlanCo { List clauses = node.getCriteria(); - List probeSymbols = Lists.transform(clauses, IndexJoinNode.EquiJoinClause::getProbe); - List indexSymbols = Lists.transform(clauses, IndexJoinNode.EquiJoinClause::getIndex); + List probeVariables = clauses.stream().map(IndexJoinNode.EquiJoinClause::getProbe).collect(toImmutableList()); + List indexVariables = clauses.stream().map(IndexJoinNode.EquiJoinClause::getIndex).collect(toImmutableList()); // Plan probe side PhysicalOperation probeSource = node.getProbeSource().accept(this, context); - List probeChannels = getChannelsForSymbols(probeSymbols, probeSource.getLayout()); - OptionalInt probeHashChannel = node.getProbeHashSymbol().map(channelGetter(probeSource)) + List probeChannels = getChannelsForVariables(probeVariables, probeSource.getLayout()); + OptionalInt probeHashChannel = node.getProbeHashVariable().map(variableChannelGetter(probeSource)) .map(OptionalInt::of).orElse(OptionalInt.empty()); // The probe key channels will be handed to the index according to probeSymbol order - Map probeKeyLayout = new HashMap<>(); - for (int i = 0; i < probeSymbols.size(); i++) { + Map probeKeyLayout = new HashMap<>(); + for (int i = 0; i < probeVariables.size(); i++) { // Duplicate symbols can appear and we only need to take take one of the Inputs - probeKeyLayout.put(probeSymbols.get(i), i); + probeKeyLayout.put(probeVariables.get(i), i); } // Plan the index source side - SetMultimap indexLookupToProbeInput = mapIndexSourceLookupSymbolToProbeKeyInput(node, probeKeyLayout); + SetMultimap indexLookupToProbeInput = mapIndexSourceLookupSymbolToProbeKeyInput(node, probeKeyLayout); LocalExecutionPlanContext indexContext = context.createIndexSourceSubContext(new IndexSourceContext(indexLookupToProbeInput)); PhysicalOperation indexSource = node.getIndexSource().accept(this, indexContext); - List indexOutputChannels = getChannelsForSymbols(indexSymbols, indexSource.getLayout()); - OptionalInt indexHashChannel = node.getIndexHashSymbol().map(channelGetter(indexSource)) + List indexOutputChannels = getChannelsForVariables(indexVariables, indexSource.getLayout()); + OptionalInt indexHashChannel = node.getIndexHashVariable().map(variableChannelGetter(indexSource)) .map(OptionalInt::of).orElse(OptionalInt.empty()); // Identify just the join keys/channels needed for lookup by the index source (does not have to use all of them). - Set indexSymbolsNeededBySource = IndexJoinOptimizer.IndexKeyTracer.trace(node.getIndexSource(), ImmutableSet.copyOf(indexSymbols)).keySet(); + Set indexVariablesNeededBySource = IndexJoinOptimizer.IndexKeyTracer.trace(node.getIndexSource(), ImmutableSet.copyOf(indexVariables)).keySet(); Set lookupSourceInputChannels = node.getCriteria().stream() - .filter(equiJoinClause -> indexSymbolsNeededBySource.contains(equiJoinClause.getIndex())) + .filter(equiJoinClause -> indexVariablesNeededBySource.contains(equiJoinClause.getIndex())) .map(IndexJoinNode.EquiJoinClause::getProbe) .map(probeKeyLayout::get) .collect(toImmutableSet()); @@ -1520,14 +1514,14 @@ public PhysicalOperation visitIndexJoin(IndexJoinNode node, LocalExecutionPlanCo Optional dynamicTupleFilterFactory = Optional.empty(); if (lookupSourceInputChannels.size() < probeKeyLayout.values().size()) { int[] nonLookupInputChannels = Ints.toArray(node.getCriteria().stream() - .filter(equiJoinClause -> !indexSymbolsNeededBySource.contains(equiJoinClause.getIndex())) + .filter(equiJoinClause -> !indexVariablesNeededBySource.contains(equiJoinClause.getIndex())) .map(IndexJoinNode.EquiJoinClause::getProbe) .map(probeKeyLayout::get) .collect(toImmutableList())); int[] nonLookupOutputChannels = Ints.toArray(node.getCriteria().stream() - .filter(equiJoinClause -> !indexSymbolsNeededBySource.contains(equiJoinClause.getIndex())) + .filter(equiJoinClause -> !indexVariablesNeededBySource.contains(equiJoinClause.getIndex())) .map(IndexJoinNode.EquiJoinClause::getIndex) - .map(indexSource.getLayout()::get) + .map(variable -> indexSource.getLayout().get(variable)) .collect(toImmutableList())); int filterOperatorId = indexContext.getNextOperatorId(); @@ -1571,13 +1565,13 @@ public PhysicalOperation visitIndexJoin(IndexJoinNode node, LocalExecutionPlanCo lifespan -> indexLookupSourceFactory, indexLookupSourceFactory.getOutputTypes()); - ImmutableMap.Builder outputMappings = ImmutableMap.builder(); + ImmutableMap.Builder outputMappings = ImmutableMap.builder(); outputMappings.putAll(probeSource.getLayout()); // inputs from index side of the join are laid out following the input from the probe side, // so adjust the channel ids but keep the field layouts intact int offset = probeSource.getTypes().size(); - for (Map.Entry entry : indexSource.getLayout().entrySet()) { + for (Map.Entry entry : indexSource.getLayout().entrySet()) { Integer input = entry.getValue(); outputMappings.put(entry.getKey(), offset + input); } @@ -1605,15 +1599,15 @@ public PhysicalOperation visitJoin(JoinNode node, LocalExecutionPlanContext cont } List clauses = node.getCriteria(); - List leftSymbols = Lists.transform(clauses, JoinNode.EquiJoinClause::getLeft); - List rightSymbols = Lists.transform(clauses, JoinNode.EquiJoinClause::getRight); + List leftVariables = Lists.transform(clauses, JoinNode.EquiJoinClause::getLeft); + List rightVariables = Lists.transform(clauses, JoinNode.EquiJoinClause::getRight); switch (node.getType()) { case INNER: case LEFT: case RIGHT: case FULL: - return createLookupJoin(node, node.getLeft(), leftSymbols, node.getLeftHashSymbol(), node.getRight(), rightSymbols, node.getRightHashSymbol(), context); + return createLookupJoin(node, node.getLeft(), leftVariables, node.getLeftHashVariable(), node.getRight(), rightVariables, node.getRightHashVariable(), context); default: throw new UnsupportedOperationException("Unsupported join type: " + node.getType()); } @@ -1638,7 +1632,7 @@ public PhysicalOperation visitSpatialJoin(SpatialJoinNode node, LocalExecutionPl if (functionMetadata.getOperatorType().get() == OperatorType.LESS_THAN || functionMetadata.getOperatorType().get() == OperatorType.LESS_THAN_OR_EQUAL) { // ST_Distance(a, b) <= r RowExpression radius = spatialComparison.getArguments().get(1); - if (radius instanceof VariableReferenceExpression && node.getRight().getOutputSymbols().contains(new Symbol(((VariableReferenceExpression) radius).getName()))) { + if (radius instanceof VariableReferenceExpression && node.getRight().getOutputVariables().contains(radius)) { CallExpression spatialFunction = (CallExpression) spatialComparison.getArguments().get(0); Optional operation = tryCreateSpatialJoin( context, @@ -1672,35 +1666,35 @@ private Optional tryCreateSpatialJoin( return Optional.empty(); } - VariableReferenceExpression firstSymbol = (VariableReferenceExpression) arguments.get(0); - VariableReferenceExpression secondSymbol = (VariableReferenceExpression) arguments.get(1); + VariableReferenceExpression firstVariable = (VariableReferenceExpression) arguments.get(0); + VariableReferenceExpression secondVariable = (VariableReferenceExpression) arguments.get(1); PlanNode probeNode = node.getLeft(); - Set probeSymbols = getSymbolReferences(probeNode.getOutputSymbols()); + Set probeSymbols = getSymbolReferences(probeNode.getOutputVariables()); PlanNode buildNode = node.getRight(); - Set buildSymbols = getSymbolReferences(buildNode.getOutputSymbols()); + Set buildSymbols = getSymbolReferences(buildNode.getOutputVariables()); - if (probeSymbols.contains(new SymbolReference(firstSymbol.getName())) && buildSymbols.contains(new SymbolReference(secondSymbol.getName()))) { + if (probeSymbols.contains(new SymbolReference(firstVariable.getName())) && buildSymbols.contains(new SymbolReference(secondVariable.getName()))) { return Optional.of(createSpatialLookupJoin( node, probeNode, - new Symbol(firstSymbol.getName()), + firstVariable, buildNode, - new Symbol(secondSymbol.getName()), - radius.map(r -> new Symbol(r.getName())), + secondVariable, + radius, spatialTest(spatialFunction, true, comparisonOperator), filterExpression, context)); } - else if (probeSymbols.contains(new SymbolReference(secondSymbol.getName())) && buildSymbols.contains(new SymbolReference(firstSymbol.getName()))) { + else if (probeSymbols.contains(new SymbolReference(secondVariable.getName())) && buildSymbols.contains(new SymbolReference(firstVariable.getName()))) { return Optional.of(createSpatialLookupJoin( node, probeNode, - new Symbol(secondSymbol.getName()), + secondVariable, buildNode, - new Symbol(firstSymbol.getName()), - radius.map(r -> new Symbol(r.getName())), + firstVariable, + radius, spatialTest(spatialFunction, false, comparisonOperator), filterExpression, context)); @@ -1749,9 +1743,9 @@ else if (comparisonOperator.get() == OperatorType.LESS_THAN_OR_EQUAL) { } } - private Set getSymbolReferences(Collection symbols) + private Set getSymbolReferences(Collection variables) { - return symbols.stream().map(Symbol::toSymbolReference).collect(toImmutableSet()); + return variables.stream().map(VariableReferenceExpression::getName).map(SymbolReference::new).collect(toImmutableSet()); } private PhysicalOperation createNestedLoopJoin(JoinNode node, LocalExecutionPlanContext context) @@ -1788,13 +1782,13 @@ private PhysicalOperation createNestedLoopJoin(JoinNode node, LocalExecutionPlan buildContext.getDriverInstanceCount(), buildSource.getPipelineExecutionStrategy()); - ImmutableMap.Builder outputMappings = ImmutableMap.builder(); + ImmutableMap.Builder outputMappings = ImmutableMap.builder(); outputMappings.putAll(probeSource.getLayout()); // inputs from build side of the join are laid out following the input from the probe side, // so adjust the channel ids but keep the field layouts intact int offset = probeSource.getTypes().size(); - for (Map.Entry entry : buildSource.getLayout().entrySet()) { + for (Map.Entry entry : buildSource.getLayout().entrySet()) { outputMappings.put(entry.getKey(), offset + entry.getValue()); } @@ -1805,10 +1799,10 @@ private PhysicalOperation createNestedLoopJoin(JoinNode node, LocalExecutionPlan private PhysicalOperation createSpatialLookupJoin( SpatialJoinNode node, PlanNode probeNode, - Symbol probeSymbol, + VariableReferenceExpression probeVariable, PlanNode buildNode, - Symbol buildSymbol, - Optional radiusSymbol, + VariableReferenceExpression buildVariable, + Optional radiusVariable, SpatialPredicate spatialRelationshipTest, Optional joinFilter, LocalExecutionPlanContext context) @@ -1819,20 +1813,19 @@ private PhysicalOperation createSpatialLookupJoin( // Plan build PagesSpatialIndexFactory pagesSpatialIndexFactory = createPagesSpatialIndexFactory(node, buildNode, - buildSymbol, - radiusSymbol, + buildVariable, + radiusVariable, probeSource.getLayout(), spatialRelationshipTest, joinFilter, context); - OperatorFactory operator = createSpatialLookupJoin(node, probeNode, probeSource, probeSymbol, pagesSpatialIndexFactory, context); + OperatorFactory operator = createSpatialLookupJoin(node, probeNode, probeSource, probeVariable, pagesSpatialIndexFactory, context); - ImmutableMap.Builder outputMappings = ImmutableMap.builder(); - List outputSymbols = node.getOutputSymbols(); - for (int i = 0; i < outputSymbols.size(); i++) { - Symbol symbol = outputSymbols.get(i); - outputMappings.put(symbol, i); + ImmutableMap.Builder outputMappings = ImmutableMap.builder(); + List outputVariables = node.getOutputVariables(); + for (int i = 0; i < outputVariables.size(); i++) { + outputMappings.put(outputVariables.get(i), i); } return new PhysicalOperation(operator, outputMappings.build(), context, probeSource); @@ -1841,19 +1834,19 @@ private PhysicalOperation createSpatialLookupJoin( private OperatorFactory createSpatialLookupJoin(SpatialJoinNode node, PlanNode probeNode, PhysicalOperation probeSource, - Symbol probeSymbol, + VariableReferenceExpression probeVariable, PagesSpatialIndexFactory pagesSpatialIndexFactory, LocalExecutionPlanContext context) { List probeTypes = probeSource.getTypes(); - List probeOutputSymbols = node.getOutputSymbols().stream() - .filter(symbol -> probeNode.getOutputSymbols().contains(symbol)) + List probeOutputVariables = node.getOutputVariables().stream() + .filter(probeNode.getOutputVariables()::contains) .collect(toImmutableList()); - List probeOutputChannels = ImmutableList.copyOf(getChannelsForSymbols(probeOutputSymbols, probeSource.getLayout())); - Function probeChannelGetter = channelGetter(probeSource); - int probeChannel = probeChannelGetter.apply(probeSymbol); + List probeOutputChannels = ImmutableList.copyOf(getChannelsForVariables(probeOutputVariables, probeSource.getLayout())); + Function probeChannelGetter = variableChannelGetter(probeSource); + int probeChannel = probeChannelGetter.apply(probeVariable); - Optional partitionChannel = node.getLeftPartitionSymbol().map(probeChannelGetter::apply); + Optional partitionChannel = node.getLeftPartitionVariable().map(probeChannelGetter); return new SpatialJoinOperatorFactory( context.getNextOperatorId(), @@ -1869,23 +1862,23 @@ private OperatorFactory createSpatialLookupJoin(SpatialJoinNode node, private PagesSpatialIndexFactory createPagesSpatialIndexFactory( SpatialJoinNode node, PlanNode buildNode, - Symbol buildSymbol, - Optional radiusSymbol, - Map probeLayout, + VariableReferenceExpression buildVariable, + Optional radiusVariable, + Map probeLayout, SpatialPredicate spatialRelationshipTest, Optional joinFilter, LocalExecutionPlanContext context) { LocalExecutionPlanContext buildContext = context.createSubContext(); PhysicalOperation buildSource = buildNode.accept(this, buildContext); - List buildOutputSymbols = node.getOutputSymbols().stream() - .filter(symbol -> buildNode.getOutputSymbols().contains(symbol)) + List buildOutputVariables = node.getOutputVariables().stream() + .filter(buildNode.getOutputVariables()::contains) .collect(toImmutableList()); - Map buildLayout = buildSource.getLayout(); - List buildOutputChannels = ImmutableList.copyOf(getChannelsForSymbols(buildOutputSymbols, buildLayout)); - Function buildChannelGetter = channelGetter(buildSource); - Integer buildChannel = buildChannelGetter.apply(buildSymbol); - Optional radiusChannel = radiusSymbol.map(buildChannelGetter::apply); + Map buildLayout = buildSource.getLayout(); + List buildOutputChannels = ImmutableList.copyOf(getChannelsForVariables(buildOutputVariables, buildLayout)); + Function buildChannelGetter = variableChannelGetter(buildSource); + Integer buildChannel = buildChannelGetter.apply(buildVariable); + Optional radiusChannel = radiusVariable.map(buildChannelGetter::apply); Optional filterFunctionFactory = joinFilter .map(filterExpression -> compileJoinFilterFunction( @@ -1893,7 +1886,7 @@ private PagesSpatialIndexFactory createPagesSpatialIndexFactory( probeLayout, buildLayout)); - Optional partitionChannel = node.getRightPartitionSymbol().map(buildChannelGetter::apply); + Optional partitionChannel = node.getRightPartitionVariable().map(buildChannelGetter); SpatialIndexBuilderOperatorFactory builderOperatorFactory = new SpatialIndexBuilderOperatorFactory( buildContext.getNextOperatorId(), @@ -1924,11 +1917,11 @@ private PagesSpatialIndexFactory createPagesSpatialIndexFactory( private PhysicalOperation createLookupJoin(JoinNode node, PlanNode probeNode, - List probeSymbols, - Optional probeHashSymbol, + List probeVariables, + Optional probeHashVariable, PlanNode buildNode, - List buildSymbols, - Optional buildHashSymbol, + List buildVariables, + Optional buildHashVariable, LocalExecutionPlanContext context) { // Plan probe @@ -1936,15 +1929,14 @@ private PhysicalOperation createLookupJoin(JoinNode node, // Plan build JoinBridgeManager lookupSourceFactory = - createLookupSourceFactory(node, buildNode, buildSymbols, buildHashSymbol, probeSource, context); + createLookupSourceFactory(node, buildNode, buildVariables, buildHashVariable, probeSource, context); - OperatorFactory operator = createLookupJoin(node, probeSource, probeSymbols, probeHashSymbol, lookupSourceFactory, context); + OperatorFactory operator = createLookupJoin(node, probeSource, probeVariables, probeHashVariable, lookupSourceFactory, context); - ImmutableMap.Builder outputMappings = ImmutableMap.builder(); - List outputSymbols = node.getOutputSymbols(); - for (int i = 0; i < outputSymbols.size(); i++) { - Symbol symbol = outputSymbols.get(i); - outputMappings.put(symbol, i); + ImmutableMap.Builder outputMappings = ImmutableMap.builder(); + List outputVariables = node.getOutputVariables(); + for (int i = 0; i < outputVariables.size(); i++) { + outputMappings.put(outputVariables.get(i), i); } return new PhysicalOperation(operator, outputMappings.build(), context, probeSource); @@ -1953,8 +1945,8 @@ private PhysicalOperation createLookupJoin(JoinNode node, private JoinBridgeManager createLookupSourceFactory( JoinNode node, PlanNode buildNode, - List buildSymbols, - Optional buildHashSymbol, + List buildVariables, + Optional buildHashVariable, PhysicalOperation probeSource, LocalExecutionPlanContext context) { @@ -1967,12 +1959,12 @@ private JoinBridgeManager createLookupSourceFact "Build execution is GROUPED_EXECUTION. Probe execution is expected be GROUPED_EXECUTION, but is UNGROUPED_EXECUTION."); } - List buildOutputSymbols = node.getOutputSymbols().stream() - .filter(symbol -> node.getRight().getOutputSymbols().contains(symbol)) + List buildOutputVariables = node.getOutputVariables().stream() + .filter(node.getRight().getOutputVariables()::contains) .collect(toImmutableList()); - List buildOutputChannels = ImmutableList.copyOf(getChannelsForSymbols(buildOutputSymbols, buildSource.getLayout())); - List buildChannels = ImmutableList.copyOf(getChannelsForSymbols(buildSymbols, buildSource.getLayout())); - OptionalInt buildHashChannel = buildHashSymbol.map(channelGetter(buildSource)) + List buildOutputChannels = ImmutableList.copyOf(getChannelsForVariables(buildOutputVariables, buildSource.getLayout())); + List buildChannels = ImmutableList.copyOf(getChannelsForVariables(buildVariables, buildSource.getLayout())); + OptionalInt buildHashChannel = buildHashVariable.map(variableChannelGetter(buildSource)) .map(OptionalInt::of).orElse(OptionalInt.empty()); boolean spillEnabled = isSpillEnabled(context.getSession()); @@ -2051,19 +2043,19 @@ private JoinBridgeManager createLookupSourceFact private JoinFilterFunctionFactory compileJoinFilterFunction( RowExpression filterExpression, - Map probeLayout, - Map buildLayout) + Map probeLayout, + Map buildLayout) { - Map joinSourcesLayout = createJoinSourcesLayout(buildLayout, probeLayout); + Map joinSourcesLayout = createJoinSourcesLayout(buildLayout, probeLayout); return joinFilterFunctionCompiler.compileJoinFilterFunction(bindChannels(filterExpression, joinSourcesLayout), buildLayout.size()); } private int sortExpressionAsSortChannel( RowExpression sortExpression, - Map probeLayout, - Map buildLayout) + Map probeLayout, + Map buildLayout) { - Map joinSourcesLayout = createJoinSourcesLayout(buildLayout, probeLayout); + Map joinSourcesLayout = createJoinSourcesLayout(buildLayout, probeLayout); RowExpression rewrittenSortExpression = bindChannels(sortExpression, joinSourcesLayout); checkArgument(rewrittenSortExpression instanceof InputReferenceExpression, "Unsupported expression type [%s]", rewrittenSortExpression); return ((InputReferenceExpression) rewrittenSortExpression).getField(); @@ -2072,18 +2064,18 @@ private int sortExpressionAsSortChannel( private OperatorFactory createLookupJoin( JoinNode node, PhysicalOperation probeSource, - List probeSymbols, - Optional probeHashSymbol, + List probeVariables, + Optional probeHashVariable, JoinBridgeManager lookupSourceFactoryManager, LocalExecutionPlanContext context) { List probeTypes = probeSource.getTypes(); - List probeOutputSymbols = node.getOutputSymbols().stream() - .filter(symbol -> node.getLeft().getOutputSymbols().contains(symbol)) + List probeOutputVariables = node.getOutputVariables().stream() + .filter(node.getLeft().getOutputVariables()::contains) .collect(toImmutableList()); - List probeOutputChannels = ImmutableList.copyOf(getChannelsForSymbols(probeOutputSymbols, probeSource.getLayout())); - List probeJoinChannels = ImmutableList.copyOf(getChannelsForSymbols(probeSymbols, probeSource.getLayout())); - OptionalInt probeHashChannel = probeHashSymbol.map(channelGetter(probeSource)) + List probeOutputChannels = ImmutableList.copyOf(getChannelsForVariables(probeOutputVariables, probeSource.getLayout())); + List probeJoinChannels = ImmutableList.copyOf(getChannelsForVariables(probeVariables, probeSource.getLayout())); + OptionalInt probeHashChannel = probeHashVariable.map(variableChannelGetter(probeSource)) .map(OptionalInt::of).orElse(OptionalInt.empty()); OptionalInt totalOperatorsCount = getJoinOperatorsCountForSpill(context, session); @@ -2110,11 +2102,11 @@ private OptionalInt getJoinOperatorsCountForSpill(LocalExecutionPlanContext cont return driverInstanceCount; } - private Map createJoinSourcesLayout(Map lookupSourceLayout, Map probeSourceLayout) + private Map createJoinSourcesLayout(Map lookupSourceLayout, Map probeSourceLayout) { - Builder joinSourcesLayout = ImmutableMap.builder(); + Builder joinSourcesLayout = ImmutableMap.builder(); joinSourcesLayout.putAll(lookupSourceLayout); - for (Map.Entry probeLayoutEntry : probeSourceLayout.entrySet()) { + for (Map.Entry probeLayoutEntry : probeSourceLayout.entrySet()) { joinSourcesLayout.put(probeLayoutEntry.getKey(), probeLayoutEntry.getValue() + lookupSourceLayout.size()); } return joinSourcesLayout.build(); @@ -2132,10 +2124,10 @@ public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanCont checkState(buildSource.getPipelineExecutionStrategy() == probeSource.getPipelineExecutionStrategy(), "build and probe have different pipelineExecutionStrategy"); checkArgument(buildContext.getDriverInstanceCount().orElse(1) == 1, "Expected local execution to not be parallel"); - int probeChannel = probeSource.getLayout().get(node.getSourceJoinSymbol()); - int buildChannel = buildSource.getLayout().get(node.getFilteringSourceJoinSymbol()); + int probeChannel = probeSource.getLayout().get(node.getSourceJoinVariable()); + int buildChannel = buildSource.getLayout().get(node.getFilteringSourceJoinVariable()); - Optional buildHashChannel = node.getFilteringSourceHashSymbol().map(channelGetter(buildSource)); + Optional buildHashChannel = node.getFilteringSourceHashVariable().map(variableChannelGetter(buildSource)); SetBuilderOperatorFactory setBuilderOperatorFactory = new SetBuilderOperatorFactory( buildContext.getNextOperatorId(), @@ -2157,7 +2149,7 @@ public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanCont buildSource.getPipelineExecutionStrategy()); // Source channels are always laid out first, followed by the boolean output symbol - Map outputMappings = ImmutableMap.builder() + Map outputMappings = ImmutableMap.builder() .putAll(probeSource.getLayout()) .put(node.getSemiJoinOutput(), probeSource.getLayout().size()) .build(); @@ -2180,14 +2172,14 @@ public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPl // serialize writes by forcing data through a single writer PhysicalOperation source = node.getSource().accept(this, context); - ImmutableMap.Builder outputMapping = ImmutableMap.builder(); - outputMapping.put(node.getOutputSymbols().get(ROW_COUNT_CHANNEL), ROW_COUNT_CHANNEL); - outputMapping.put(node.getOutputSymbols().get(FRAGMENT_CHANNEL), FRAGMENT_CHANNEL); - outputMapping.put(node.getOutputSymbols().get(CONTEXT_CHANNEL), CONTEXT_CHANNEL); + ImmutableMap.Builder outputMapping = ImmutableMap.builder(); + outputMapping.put(node.getRowCountVariable(), ROW_COUNT_CHANNEL); + outputMapping.put(node.getFragmentVariable(), FRAGMENT_CHANNEL); + outputMapping.put(node.getTableCommitContextVariable(), CONTEXT_CHANNEL); OperatorFactory statisticsAggregation = node.getStatisticsAggregation().map(aggregation -> { - List groupingSymbols = aggregation.getGroupingSymbols(); - if (groupingSymbols.isEmpty()) { + List groupingVariables = aggregation.getGroupingVariables(); + if (groupingVariables.isEmpty()) { return createAggregationOperatorFactory( node.getId(), aggregation.getAggregations(), @@ -2202,7 +2194,7 @@ public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPl node.getId(), aggregation.getAggregations(), ImmutableSet.of(), - groupingSymbols, + groupingVariables, PARTIAL, Optional.empty(), Optional.empty(), @@ -2225,7 +2217,7 @@ public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPl }).orElse(new DevNullOperatorFactory(context.getNextOperatorId(), node.getId())); List inputChannels = node.getColumns().stream() - .map(source::symbolToChannel) + .map(source::variableToChannel) .collect(toImmutableList()); OperatorFactory operatorFactory = new TableWriterOperatorFactory( @@ -2236,7 +2228,7 @@ public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPl inputChannels, session, statisticsAggregation, - getSymbolTypes(node.getOutputSymbols(), context.getTypes()), + getVariableTypes(node.getOutputVariables()), tableCommitContextCodec, stageExecutionDescriptor.isRecoverableGroupedExecution()); @@ -2248,7 +2240,7 @@ public PhysicalOperation visitStatisticsWriterNode(StatisticsWriterNode node, Lo { PhysicalOperation source = node.getSource().accept(this, context); - StatisticAggregationsDescriptor descriptor = node.getDescriptor().map(symbol -> source.getLayout().get(symbol)); + StatisticAggregationsDescriptor descriptor = node.getDescriptor().map(source.getLayout()::get); OperatorFactory operatorFactory = new StatisticsWriterOperatorFactory( context.getNextOperatorId(), @@ -2264,7 +2256,7 @@ public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPl { PhysicalOperation source = node.getSource().accept(this, context); - ImmutableMap.Builder outputMapping = ImmutableMap.builder(); + ImmutableMap.Builder outputMapping = ImmutableMap.builder(); OperatorFactory statisticsAggregation = node.getStatisticsAggregation().map(aggregation -> { List groupingSymbols = aggregation.getGroupingSymbols(); @@ -2283,7 +2275,7 @@ public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPl node.getId(), aggregation.getAggregations(), ImmutableSet.of(), - groupingSymbols, + aggregation.getGroupingVariables(), FINAL, Optional.empty(), Optional.empty(), @@ -2301,7 +2293,7 @@ public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPl true); }).orElse(new DevNullOperatorFactory(context.getNextOperatorId(), node.getId())); - Map aggregationOutput = outputMapping.build(); + Map aggregationOutput = outputMapping.build(); StatisticAggregationsDescriptor descriptor = node.getStatisticsAggregationDescriptor() .map(desc -> desc.map(aggregationOutput::get)) .orElse(StatisticAggregationsDescriptor.empty()); @@ -2315,7 +2307,7 @@ public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPl descriptor, session, tableCommitContextCodec); - Map layout = ImmutableMap.of(node.getOutputSymbols().get(0), 0); + Map layout = ImmutableMap.of(node.getOutputVariables().get(0), 0); return new PhysicalOperation(operatorFactory, layout, context, source); } @@ -2327,9 +2319,9 @@ public PhysicalOperation visitDelete(DeleteNode node, LocalExecutionPlanContext OperatorFactory operatorFactory = new DeleteOperatorFactory(context.getNextOperatorId(), node.getId(), source.getLayout().get(node.getRowId()), tableCommitContextCodec); - Map layout = ImmutableMap.builder() - .put(node.getOutputSymbols().get(0), 0) - .put(node.getOutputSymbols().get(1), 1) + Map layout = ImmutableMap.builder() + .put(node.getOutputVariables().get(0), 0) + .put(node.getOutputVariables().get(1), 1) .build(); return new PhysicalOperation(operatorFactory, layout, context, source); @@ -2405,7 +2397,7 @@ private PhysicalOperation createLocalMerge(ExchangeNode node, LocalExecutionPlan maxLocalExchangeBufferSize); List operatorFactories = new ArrayList<>(source.getOperatorFactories()); - List expectedLayout = node.getInputs().get(0); + List expectedLayout = node.getInputs().get(0); Function pagePreprocessor = enforceLayoutProcessor(expectedLayout, source.getLayout()); operatorFactories.add(new LocalExchangeSinkOperatorFactory( exchangeFactory, @@ -2418,8 +2410,8 @@ private PhysicalOperation createLocalMerge(ExchangeNode node, LocalExecutionPlan context.setInputDriver(false); OrderingScheme orderingScheme = node.getOrderingScheme().get(); - ImmutableMap layout = makeLayout(node); - List sortChannels = getChannelsForSymbols(orderingScheme.getOrderBy(), layout); + ImmutableMap layout = makeLayout(node); + List sortChannels = getChannelsForVariables(orderingScheme.getOrderBy(), layout); List orderings = orderingScheme.getOrderingList(); OperatorFactory operatorFactory = new LocalMergeSourceOperatorFactory( context.getNextOperatorId(), @@ -2449,10 +2441,10 @@ else if (context.getDriverInstanceCount().isPresent()) { List types = getSourceOperatorTypes(node, context.getTypes()); List channels = node.getPartitioningScheme().getPartitioning().getArguments().stream() - .map(argument -> node.getOutputSymbols().indexOf(argument.getColumn())) + .map(argument -> node.getOutputVariables().indexOf(argument.getVariableReference())) .collect(toImmutableList()); Optional hashChannel = node.getPartitioningScheme().getHashColumn() - .map(symbol -> node.getOutputSymbols().indexOf(symbol)); + .map(variable -> node.getOutputVariables().indexOf(variable)); PipelineExecutionStrategy exchangeSourcePipelineExecutionStrategy = GROUPED_EXECUTION; List driverFactoryParametersList = new ArrayList<>(); @@ -2481,7 +2473,7 @@ else if (context.getDriverInstanceCount().isPresent()) { PhysicalOperation source = driverFactoryParameters.getSource(); LocalExecutionPlanContext subContext = driverFactoryParameters.getSubContext(); - List expectedLayout = node.getInputs().get(i); + List expectedLayout = node.getInputs().get(i); Function pagePreprocessor = enforceLayoutProcessor(expectedLayout, source.getLayout()); List operatorFactories = new ArrayList<>(source.getOperatorFactories()); @@ -2517,28 +2509,30 @@ protected PhysicalOperation visitPlan(PlanNode node, LocalExecutionPlanContext c private List getSourceOperatorTypes(PlanNode node, TypeProvider types) { - return getSymbolTypes(node.getOutputSymbols(), types); + return getVariableTypes(node.getOutputVariables()); } - private List getSymbolTypes(List symbols, TypeProvider types) + private List getVariableTypes(List variables) { - return symbols.stream() - .map(types::get) + return variables.stream() + .map(VariableReferenceExpression::getType) .collect(toImmutableList()); } private AccumulatorFactory buildAccumulatorFactory( PhysicalOperation source, - Aggregation aggregation) + Aggregation aggregation, + TypeProvider types) { FunctionManager functionManager = metadata.getFunctionManager(); InternalAggregationFunction internalAggregationFunction = functionManager.getAggregateFunctionImplementation(aggregation.getFunctionHandle()); List valueChannels = new ArrayList<>(); - for (Expression argument : aggregation.getArguments()) { + for (int i = 0; i < aggregation.getArguments().size(); i++) { + Expression argument = aggregation.getArguments().get(i); if (!(argument instanceof LambdaExpression)) { - Symbol argumentSymbol = Symbol.from(argument); - valueChannels.add(source.getLayout().get(argumentSymbol)); + VariableReferenceExpression argumentVariable = new VariableReferenceExpression(Symbol.from(argument).getName(), types.get(Symbol.from(argument))); + valueChannels.add(source.getLayout().get(argumentVariable)); } } @@ -2610,7 +2604,7 @@ private AccumulatorFactory buildAccumulatorFactory( Optional maskChannel = aggregation.getMask().map(value -> source.getLayout().get(value)); List sortOrders = ImmutableList.of(); - List sortKeys = ImmutableList.of(); + List sortKeys = ImmutableList.of(); if (aggregation.getOrderBy().isPresent()) { OrderingScheme orderBy = aggregation.getOrderBy().get(); sortKeys = orderBy.getOrderBy(); @@ -2621,7 +2615,7 @@ private AccumulatorFactory buildAccumulatorFactory( valueChannels, maskChannel, source.getTypes(), - getChannelsForSymbols(sortKeys, source.getLayout()), + getChannelsForVariables(sortKeys, source.getLayout()), sortOrders, pagesIndexFactory, aggregation.isDistinct(), @@ -2632,7 +2626,7 @@ private AccumulatorFactory buildAccumulatorFactory( private PhysicalOperation planGlobalAggregation(AggregationNode node, PhysicalOperation source, LocalExecutionPlanContext context) { - ImmutableMap.Builder outputMappings = ImmutableMap.builder(); + ImmutableMap.Builder outputMappings = ImmutableMap.builder(); AggregationOperatorFactory operatorFactory = createAggregationOperatorFactory( node.getId(), node.getAggregations(), @@ -2647,21 +2641,21 @@ private PhysicalOperation planGlobalAggregation(AggregationNode node, PhysicalOp private AggregationOperatorFactory createAggregationOperatorFactory( PlanNodeId planNodeId, - Map aggregations, + Map aggregations, Step step, int startOutputChannel, - ImmutableMap.Builder outputMappings, + ImmutableMap.Builder outputMappings, PhysicalOperation source, LocalExecutionPlanContext context, boolean useSystemMemory) { int outputChannel = startOutputChannel; ImmutableList.Builder accumulatorFactories = ImmutableList.builder(); - for (Map.Entry entry : aggregations.entrySet()) { - Symbol symbol = entry.getKey(); + for (Map.Entry entry : aggregations.entrySet()) { + VariableReferenceExpression variable = entry.getKey(); Aggregation aggregation = entry.getValue(); - accumulatorFactories.add(buildAccumulatorFactory(source, aggregation)); - outputMappings.put(symbol, outputChannel); // one aggregation per channel + accumulatorFactories.add(buildAccumulatorFactory(source, aggregation, context.getTypes())); + outputMappings.put(variable, outputChannel); // one aggregation per channel outputChannel++; } return new AggregationOperatorFactory(context.getNextOperatorId(), planNodeId, step, accumulatorFactories.build(), useSystemMemory); @@ -2674,15 +2668,15 @@ private PhysicalOperation planGroupByAggregation( DataSize unspillMemoryLimit, LocalExecutionPlanContext context) { - ImmutableMap.Builder mappings = ImmutableMap.builder(); + ImmutableMap.Builder mappings = ImmutableMap.builder(); OperatorFactory operatorFactory = createHashAggregationOperatorFactory( node.getId(), node.getAggregations(), node.getGlobalGroupingSets(), node.getGroupingKeys(), node.getStep(), - node.getHashSymbol(), - node.getGroupIdSymbol(), + node.getHashVariable(), + node.getGroupIdVariable(), source, node.hasDefaultOutput(), spillEnabled, @@ -2699,12 +2693,12 @@ private PhysicalOperation planGroupByAggregation( private OperatorFactory createHashAggregationOperatorFactory( PlanNodeId planNodeId, - Map aggregations, + Map aggregations, Set globalGroupingSets, - List groupBySymbols, + List groupbyVariables, Step step, - Optional hashSymbol, - Optional groupIdSymbol, + Optional hashVariable, + Optional groupIdVariable, PhysicalOperation source, boolean hasDefaultOutput, boolean spillEnabled, @@ -2712,44 +2706,44 @@ private OperatorFactory createHashAggregationOperatorFactory( DataSize unspillMemoryLimit, LocalExecutionPlanContext context, int startOutputChannel, - ImmutableMap.Builder outputMappings, + ImmutableMap.Builder outputMappings, int expectedGroups, Optional maxPartialAggregationMemorySize, boolean useSystemMemory) { - List aggregationOutputSymbols = new ArrayList<>(); + List aggregationOutputSymbols = new ArrayList<>(); List accumulatorFactories = new ArrayList<>(); - for (Map.Entry entry : aggregations.entrySet()) { - Symbol symbol = entry.getKey(); + for (Map.Entry entry : aggregations.entrySet()) { + VariableReferenceExpression variable = entry.getKey(); Aggregation aggregation = entry.getValue(); - accumulatorFactories.add(buildAccumulatorFactory(source, aggregation)); - aggregationOutputSymbols.add(symbol); + accumulatorFactories.add(buildAccumulatorFactory(source, aggregation, context.getTypes())); + aggregationOutputSymbols.add(variable); } // add group-by key fields each in a separate channel int channel = startOutputChannel; Optional groupIdChannel = Optional.empty(); - for (Symbol symbol : groupBySymbols) { - outputMappings.put(symbol, channel); - if (groupIdSymbol.isPresent() && groupIdSymbol.get().equals(symbol)) { + for (VariableReferenceExpression variable : groupbyVariables) { + outputMappings.put(variable, channel); + if (groupIdVariable.isPresent() && groupIdVariable.get().equals(variable)) { groupIdChannel = Optional.of(channel); } channel++; } // hashChannel follows the group by channels - if (hashSymbol.isPresent()) { - outputMappings.put(hashSymbol.get(), channel++); + if (hashVariable.isPresent()) { + outputMappings.put(hashVariable.get(), channel++); } // aggregations go in following channels - for (Symbol symbol : aggregationOutputSymbols) { - outputMappings.put(symbol, channel); + for (VariableReferenceExpression variable : aggregationOutputSymbols) { + outputMappings.put(variable, channel); channel++; } - List groupByChannels = getChannelsForSymbols(groupBySymbols, source.getLayout()); + List groupByChannels = getChannelsForVariables(groupbyVariables, source.getLayout()); List groupByTypes = groupByChannels.stream() .map(entry -> source.getTypes().get(entry)) .collect(toImmutableList()); @@ -2766,7 +2760,7 @@ private OperatorFactory createHashAggregationOperatorFactory( joinCompiler); } else { - Optional hashChannel = hashSymbol.map(channelGetter(source)); + Optional hashChannel = hashVariable.map(variableChannelGetter(source)); return new HashAggregationOperatorFactory( context.getNextOperatorId(), planNodeId, @@ -2833,10 +2827,10 @@ else if (target instanceof InsertHandle) { }; } - private static Function enforceLayoutProcessor(List expectedLayout, Map inputLayout) + private static Function enforceLayoutProcessor(List expectedLayout, Map inputLayout) { int[] channels = expectedLayout.stream() - .peek(symbol -> checkArgument(inputLayout.containsKey(symbol), "channel not found for symbol: %s", symbol)) + .peek(variable -> checkArgument(inputLayout.containsKey(variable), "channel not found for variable: %s", variable)) .mapToInt(inputLayout::get) .toArray(); @@ -2857,7 +2851,17 @@ private static List getChannelsForSymbols(List symbols, Map channelGetter(PhysicalOperation source) + private static List getChannelsForVariables(Collection variables, Map layout) + { + ImmutableList.Builder builder = ImmutableList.builder(); + for (VariableReferenceExpression variable : variables) { + checkArgument(layout.containsKey(variable)); + builder.add(layout.get(variable)); + } + return builder.build(); + } + + private static Function variableChannelGetter(PhysicalOperation source) { return input -> { checkArgument(source.getLayout().containsKey(input)); @@ -2871,24 +2875,24 @@ private static Function channelGetter(PhysicalOperation source) private static class PhysicalOperation { private final List operatorFactories; - private final Map layout; + private final Map layout; private final List types; private final PipelineExecutionStrategy pipelineExecutionStrategy; - public PhysicalOperation(OperatorFactory operatorFactory, Map layout, LocalExecutionPlanContext context, PipelineExecutionStrategy pipelineExecutionStrategy) + public PhysicalOperation(OperatorFactory operatorFactory, Map layout, LocalExecutionPlanContext context, PipelineExecutionStrategy pipelineExecutionStrategy) { this(operatorFactory, layout, context, Optional.empty(), pipelineExecutionStrategy); } - public PhysicalOperation(OperatorFactory operatorFactory, Map layout, LocalExecutionPlanContext context, PhysicalOperation source) + public PhysicalOperation(OperatorFactory operatorFactory, Map layout, LocalExecutionPlanContext context, PhysicalOperation source) { this(operatorFactory, layout, context, Optional.of(requireNonNull(source, "source is null")), source.getPipelineExecutionStrategy()); } private PhysicalOperation( OperatorFactory operatorFactory, - Map layout, + Map layout, LocalExecutionPlanContext context, Optional source, PipelineExecutionStrategy pipelineExecutionStrategy) @@ -2904,26 +2908,26 @@ private PhysicalOperation( .add(operatorFactory) .build(); this.layout = ImmutableMap.copyOf(layout); - this.types = toTypes(layout, context); + this.types = toTypes(layout); this.pipelineExecutionStrategy = pipelineExecutionStrategy; } - private static List toTypes(Map layout, LocalExecutionPlanContext context) + private static List toTypes(Map layout) { // verify layout covers all values int channelCount = layout.values().stream().mapToInt(Integer::intValue).max().orElse(-1) + 1; checkArgument( layout.size() == channelCount && ImmutableSet.copyOf(layout.values()).containsAll(ContiguousSet.create(closedOpen(0, channelCount), integers())), "Layout does not have a symbol for every output channel: %s", layout); - Map channelLayout = ImmutableBiMap.copyOf(layout).inverse(); + Map channelLayout = ImmutableBiMap.copyOf(layout).inverse(); return range(0, channelCount) .mapToObj(channelLayout::get) - .map(context.getTypes()::get) + .map(VariableReferenceExpression::getType) .collect(toImmutableList()); } - public int symbolToChannel(Symbol input) + private int variableToChannel(VariableReferenceExpression input) { checkArgument(layout.containsKey(input)); return layout.get(input); @@ -2934,7 +2938,7 @@ public List getTypes() return types; } - public Map getLayout() + public Map getLayout() { return layout; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java index 1bc0f5f834f16..fe936a6c90f73 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java @@ -33,6 +33,7 @@ import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.statistics.TableStatisticsMetadata; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.Analysis; @@ -159,7 +160,7 @@ public LogicalPlanner( this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); - this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(session, symbolAllocator, metadata); + this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(symbolAllocator, metadata); this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); @@ -209,11 +210,12 @@ public PlanNode planStatement(Analysis analysis, Statement statement) if (statement instanceof CreateTableAsSelect && analysis.isCreateTableAsSelectNoOp()) { checkState(analysis.getCreateTableDestination().isPresent(), "Table destination is missing"); Symbol symbol = symbolAllocator.newSymbol("rows", BIGINT); + VariableReferenceExpression variable = new VariableReferenceExpression(symbol.getName(), BIGINT); PlanNode source = new ValuesNode( idAllocator.getNextId(), - ImmutableList.of(symbol), + ImmutableList.of(variable), ImmutableList.of(ImmutableList.of(constant(0L, BIGINT)))); - return new OutputNode(idAllocator.getNextId(), source, ImmutableList.of("rows"), ImmutableList.of(symbol)); + return new OutputNode(idAllocator.getNextId(), source, ImmutableList.of("rows"), ImmutableList.of(variable)); } return createOutputPlan(planStatementWithoutOutput(analysis, statement), analysis); } @@ -252,9 +254,9 @@ private RelationPlan createExplainAnalyzePlan(Analysis analysis, Explain stateme RelationPlan underlyingPlan = planStatementWithoutOutput(analysis, statement.getStatement()); PlanNode root = underlyingPlan.getRoot(); Scope scope = analysis.getScope(statement); - Symbol outputSymbol = symbolAllocator.newSymbol(scope.getRelationType().getFieldByIndex(0)); - root = new ExplainAnalyzeNode(idAllocator.getNextId(), root, outputSymbol, statement.isVerbose()); - return new RelationPlan(root, scope, ImmutableList.of(outputSymbol)); + VariableReferenceExpression outputVariable = symbolAllocator.newVariable(scope.getRelationType().getFieldByIndex(0)); + root = new ExplainAnalyzeNode(idAllocator.getNextId(), root, outputVariable, statement.isVerbose()); + return new RelationPlan(root, scope, ImmutableList.of(outputVariable)); } private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStatement) @@ -263,42 +265,42 @@ private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStateme // Plan table scan Map columnHandles = metadata.getColumnHandles(session, targetTable); - ImmutableList.Builder tableScanOutputs = ImmutableList.builder(); - ImmutableMap.Builder symbolToColumnHandle = ImmutableMap.builder(); - ImmutableMap.Builder columnNameToSymbol = ImmutableMap.builder(); + ImmutableList.Builder tableScanOutputsBuilder = ImmutableList.builder(); + ImmutableMap.Builder variableToColumnHandle = ImmutableMap.builder(); + ImmutableMap.Builder columnNameToVariable = ImmutableMap.builder(); TableMetadata tableMetadata = metadata.getTableMetadata(session, targetTable); for (ColumnMetadata column : tableMetadata.getColumns()) { - Symbol symbol = symbolAllocator.newSymbol(column.getName(), column.getType()); - tableScanOutputs.add(symbol); - symbolToColumnHandle.put(symbol, columnHandles.get(column.getName())); - columnNameToSymbol.put(column.getName(), symbol); + VariableReferenceExpression variable = symbolAllocator.newVariable(column.getName(), column.getType()); + tableScanOutputsBuilder.add(variable); + variableToColumnHandle.put(variable, columnHandles.get(column.getName())); + columnNameToVariable.put(column.getName(), variable); } + List tableScanOutputs = tableScanOutputsBuilder.build(); TableStatisticsMetadata tableStatisticsMetadata = metadata.getStatisticsCollectionMetadata( session, targetTable.getConnectorId().getCatalogName(), tableMetadata.getMetadata()); - TableStatisticAggregation tableStatisticAggregation = statisticsAggregationPlanner.createStatisticsAggregation(tableStatisticsMetadata, columnNameToSymbol.build()); + TableStatisticAggregation tableStatisticAggregation = statisticsAggregationPlanner.createStatisticsAggregation(tableStatisticsMetadata, columnNameToVariable.build()); StatisticAggregations statisticAggregations = tableStatisticAggregation.getAggregations(); - List groupingSymbols = statisticAggregations.getGroupingSymbols(); PlanNode planNode = new StatisticsWriterNode( idAllocator.getNextId(), new AggregationNode( idAllocator.getNextId(), - new TableScanNode(idAllocator.getNextId(), targetTable, tableScanOutputs.build(), symbolToColumnHandle.build()), + new TableScanNode(idAllocator.getNextId(), targetTable, tableScanOutputs, variableToColumnHandle.build()), statisticAggregations.getAggregations(), - singleGroupingSet(groupingSymbols), + singleGroupingSet(statisticAggregations.getGroupingVariables()), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty()), new StatisticsWriterNode.WriteStatisticsReference(targetTable), - symbolAllocator.newSymbol("rows", BIGINT), + symbolAllocator.newVariable("rows", BIGINT), tableStatisticsMetadata.getTableStatistics().contains(ROW_COUNT), tableStatisticAggregation.getDescriptor()); - return new RelationPlan(planNode, analysis.getScope(analyzeStatement), planNode.getOutputSymbols()); + return new RelationPlan(planNode, analysis.getScope(analyzeStatement), planNode.getOutputVariables()); } private RelationPlan createTableCreationPlan(Analysis analysis, Query query) @@ -352,7 +354,7 @@ private RelationPlan createInsertPlan(Analysis analysis, Insert insertStatement) if (column.isHidden()) { continue; } - Symbol output = symbolAllocator.newSymbol(column.getName(), column.getType()); + VariableReferenceExpression output = symbolAllocator.newVariable(column.getName(), column.getType()); int index = insert.getColumns().indexOf(columns.get(column.getName())); if (index < 0) { Expression cast = new Cast(new NullLiteral(), column.getType().getTypeSignature().toString()); @@ -379,7 +381,7 @@ private RelationPlan createInsertPlan(Analysis analysis, Insert insertStatement) .collect(toImmutableList()); Scope scope = Scope.builder().withRelationType(RelationId.anonymous(), new RelationType(fields)).build(); - plan = new RelationPlan(projectNode, scope, projectNode.getOutputSymbols()); + plan = new RelationPlan(projectNode, scope, projectNode.getOutputVariables()); Optional newTableLayout = metadata.getInsertLayout(session, insert.getTarget()); String catalogName = insert.getTarget().getConnectorId().getCatalogName(); @@ -415,17 +417,17 @@ private RelationPlan createTableWriterPlan( } }); - List symbols = plan.getFieldMappings(); + List variables = plan.getFieldMappings(); Optional partitioningScheme = Optional.empty(); if (writeTableLayout.isPresent()) { - List partitionFunctionArguments = new ArrayList<>(); + List partitionFunctionArguments = new ArrayList<>(); writeTableLayout.get().getPartitionColumns().stream() .mapToInt(columnNames::indexOf) - .mapToObj(symbols::get) + .mapToObj(variables::get) .forEach(partitionFunctionArguments::add); - List outputLayout = new ArrayList<>(symbols); + List outputLayout = new ArrayList<>(variables); partitioningScheme = Optional.of(new PartitioningScheme( Partitioning.create(writeTableLayout.get().getPartitioning(), partitionFunctionArguments), @@ -433,11 +435,11 @@ private RelationPlan createTableWriterPlan( } if (!statisticsMetadata.isEmpty()) { - verify(columnNames.size() == symbols.size(), "columnNames.size() != symbols.size(): %s and %s", columnNames, symbols); - Map columnToSymbolMap = zip(columnNames.stream(), symbols.stream(), SimpleImmutableEntry::new) + verify(columnNames.size() == variables.size(), "columnNames.size() != variables.size(): %s and %s", columnNames, variables); + Map columnToVariableMap = zip(columnNames.stream(), plan.getFieldMappings().stream(), SimpleImmutableEntry::new) .collect(toImmutableMap(Entry::getKey, Entry::getValue)); - TableStatisticAggregation result = statisticsAggregationPlanner.createStatisticsAggregation(statisticsMetadata, columnToSymbolMap); + TableStatisticAggregation result = statisticsAggregationPlanner.createStatisticsAggregation(statisticsMetadata, columnToVariableMap); StatisticAggregations.Parts aggregations = result.getAggregations().createPartialAggregations(symbolAllocator, metadata.getFunctionManager()); @@ -451,10 +453,10 @@ private RelationPlan createTableWriterPlan( idAllocator.getNextId(), source, target, - symbolAllocator.newSymbol("partialrows", BIGINT), - symbolAllocator.newSymbol("fragment", VARBINARY), - symbolAllocator.newSymbol("tablecommitcontext", VARBINARY), - symbols, + symbolAllocator.newVariable("partialrows", BIGINT), + symbolAllocator.newVariable("fragment", VARBINARY), + symbolAllocator.newVariable("tablecommitcontext", VARBINARY), + plan.getFieldMappings(), columnNames, partitioningScheme, Optional.of(partialAggregation), @@ -464,11 +466,11 @@ private RelationPlan createTableWriterPlan( idAllocator.getNextId(), writerNode, target, - symbolAllocator.newSymbol("rows", BIGINT), + symbolAllocator.newVariable("rows", BIGINT), Optional.of(aggregations.getFinalAggregation()), Optional.of(result.getDescriptor())); - return new RelationPlan(commitNode, analysis.getRootScope(), commitNode.getOutputSymbols()); + return new RelationPlan(commitNode, analysis.getRootScope(), commitNode.getOutputVariables()); } TableFinishNode commitNode = new TableFinishNode( @@ -477,40 +479,41 @@ private RelationPlan createTableWriterPlan( idAllocator.getNextId(), source, target, - symbolAllocator.newSymbol("partialrows", BIGINT), - symbolAllocator.newSymbol("fragment", VARBINARY), - symbolAllocator.newSymbol("tablecommitcontext", VARBINARY), - symbols, + symbolAllocator.newVariable("partialrows", BIGINT), + symbolAllocator.newVariable("fragment", VARBINARY), + symbolAllocator.newVariable("tablecommitcontext", VARBINARY), + plan.getFieldMappings(), columnNames, partitioningScheme, Optional.empty(), Optional.empty()), target, - symbolAllocator.newSymbol("rows", BIGINT), + symbolAllocator.newVariable("rows", BIGINT), Optional.empty(), Optional.empty()); - return new RelationPlan(commitNode, analysis.getRootScope(), commitNode.getOutputSymbols()); + return new RelationPlan(commitNode, analysis.getRootScope(), commitNode.getOutputVariables()); } private RelationPlan createDeletePlan(Analysis analysis, Delete node) { - DeleteNode deleteNode = new QueryPlanner(analysis, symbolAllocator, idAllocator, buildLambdaDeclarationToSymbolMap(analysis, symbolAllocator), metadata, session) + DeleteNode deleteNode = new QueryPlanner(analysis, symbolAllocator, idAllocator, buildLambdaDeclarationToVariableMap(analysis, symbolAllocator), metadata, session) .plan(node); TableFinishNode commitNode = new TableFinishNode( idAllocator.getNextId(), deleteNode, deleteNode.getTarget(), - symbolAllocator.newSymbol("rows", BIGINT), + symbolAllocator.newVariable("rows", BIGINT), Optional.empty(), Optional.empty()); - return new RelationPlan(commitNode, analysis.getScope(node), commitNode.getOutputSymbols()); + return new RelationPlan(commitNode, analysis.getScope(node), commitNode.getOutputVariables()); } private PlanNode createOutputPlan(RelationPlan plan, Analysis analysis) { ImmutableList.Builder outputs = ImmutableList.builder(); + ImmutableList.Builder outputVariables = ImmutableList.builder(); ImmutableList.Builder names = ImmutableList.builder(); int columnNumber = 0; @@ -522,16 +525,17 @@ private PlanNode createOutputPlan(RelationPlan plan, Analysis analysis) int fieldIndex = outputDescriptor.indexOf(field); Symbol symbol = plan.getSymbol(fieldIndex); outputs.add(symbol); + outputVariables.add(new VariableReferenceExpression(symbol.getName(), field.getType())); columnNumber++; } - return new OutputNode(idAllocator.getNextId(), plan.getRoot(), names.build(), outputs.build()); + return new OutputNode(idAllocator.getNextId(), plan.getRoot(), names.build(), outputVariables.build()); } private RelationPlan createRelationPlan(Analysis analysis, Query query) { - return new RelationPlanner(analysis, symbolAllocator, idAllocator, buildLambdaDeclarationToSymbolMap(analysis, symbolAllocator), metadata, session) + return new RelationPlanner(analysis, symbolAllocator, idAllocator, buildLambdaDeclarationToVariableMap(analysis, symbolAllocator), metadata, session) .process(query, null); } @@ -563,9 +567,9 @@ private static List getOutputTableColumns(RelationPlan plan, Opt return columns.build(); } - private static Map, Symbol> buildLambdaDeclarationToSymbolMap(Analysis analysis, SymbolAllocator symbolAllocator) + private static Map, VariableReferenceExpression> buildLambdaDeclarationToVariableMap(Analysis analysis, SymbolAllocator symbolAllocator) { - Map, Symbol> resultMap = new LinkedHashMap<>(); + Map, VariableReferenceExpression> resultMap = new LinkedHashMap<>(); for (Entry, Type> entry : analysis.getTypes().entrySet()) { if (!(entry.getKey().getNode() instanceof LambdaArgumentDeclaration)) { continue; @@ -574,7 +578,7 @@ private static Map, Symbol> buildLambdaDeclar if (resultMap.containsKey(lambdaArgumentDeclaration)) { continue; } - resultMap.put(lambdaArgumentDeclaration, symbolAllocator.newSymbol(lambdaArgumentDeclaration.getNode(), entry.getValue())); + resultMap.put(lambdaArgumentDeclaration, symbolAllocator.newVariable(lambdaArgumentDeclaration.getNode(), entry.getValue())); } return resultMap; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LookupSymbolResolver.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LookupSymbolResolver.java index dcd0ef41cd0cd..3b5a0e976fbf9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LookupSymbolResolver.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LookupSymbolResolver.java @@ -15,11 +15,13 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.predicate.NullableValue; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.google.common.collect.ImmutableMap; import java.util.Map; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.util.Objects.requireNonNull; public class LookupSymbolResolver @@ -28,12 +30,12 @@ public class LookupSymbolResolver private final Map assignments; private final Map bindings; - public LookupSymbolResolver(Map assignments, Map bindings) + public LookupSymbolResolver(Map assignments, Map bindings) { requireNonNull(assignments, "assignments is null"); requireNonNull(bindings, "bindings is null"); - this.assignments = ImmutableMap.copyOf(assignments); + this.assignments = assignments.entrySet().stream().collect(toImmutableMap(entry -> new Symbol(entry.getKey().getName()), Map.Entry::getValue)); this.bindings = ImmutableMap.copyOf(bindings); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/OrderingScheme.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/OrderingScheme.java index 1988bc02447e4..0d8989a4b2006 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/OrderingScheme.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/OrderingScheme.java @@ -14,11 +14,14 @@ package com.facebook.presto.sql.planner; import com.facebook.presto.spi.block.SortOrder; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Maps; import java.util.List; import java.util.Map; @@ -27,16 +30,17 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class OrderingScheme { - private final List orderBy; - private final Map orderings; + private final List orderBy; + private final Map orderings; @JsonCreator - public OrderingScheme(@JsonProperty("orderBy") List orderBy, @JsonProperty("orderings") Map orderings) + public OrderingScheme(@JsonProperty("orderBy") List orderBy, @JsonProperty("orderings") Map orderings) { requireNonNull(orderBy, "orderBy is null"); requireNonNull(orderings, "orderings is null"); @@ -47,13 +51,13 @@ public OrderingScheme(@JsonProperty("orderBy") List orderBy, @JsonProper } @JsonProperty - public List getOrderBy() + public List getOrderBy() { return orderBy; } @JsonProperty - public Map getOrderings() + public Map getOrderings() { return orderings; } @@ -65,10 +69,16 @@ public List getOrderingList() .collect(toImmutableList()); } + public SortOrder getOrdering(VariableReferenceExpression variable) + { + checkArgument(orderings.containsKey(variable), format("No ordering for variable: %s", variable)); + return orderings.get(variable); + } + + @VisibleForTesting public SortOrder getOrdering(Symbol symbol) { - checkArgument(orderings.containsKey(symbol), format("No ordering for symbol: %s", symbol)); - return orderings.get(symbol); + return getOnlyElement(Maps.filterKeys(orderings, variable -> variable.getName().equals(symbol.getName())).values()); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/Partitioning.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/Partitioning.java index c58ae05ee69fd..fef0e6771425a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/Partitioning.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/Partitioning.java @@ -15,7 +15,9 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.spi.predicate.NullableValue; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -32,6 +34,7 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Objects.requireNonNull; @@ -48,10 +51,10 @@ private Partitioning(PartitioningHandle handle, List arguments) this.arguments = ImmutableList.copyOf(requireNonNull(arguments, "arguments is null")); } - public static Partitioning create(PartitioningHandle handle, List columns) + public static Partitioning create(PartitioningHandle handle, List columns) { return new Partitioning(handle, columns.stream() - .map(ArgumentBinding::columnBinding) + .map(ArgumentBinding::new) .collect(toImmutableList())); } @@ -80,7 +83,17 @@ public Set getColumns() { return arguments.stream() .filter(ArgumentBinding::isVariable) - .map(ArgumentBinding::getColumn) + .map(ArgumentBinding::getVariableReference) + .map(VariableReferenceExpression::getName) + .map(Symbol::new) + .collect(toImmutableSet()); + } + + public Set getVariableReferences() + { + return arguments.stream() + .filter(ArgumentBinding::isVariable) + .map(ArgumentBinding::getVariableReference) .collect(toImmutableSet()); } @@ -100,9 +113,9 @@ public boolean isCompatibleWith( @Deprecated public boolean isCompatibleWith( Partitioning right, - Function> leftToRightMappings, - Function> leftConstantMapping, - Function> rightConstantMapping, + Function> leftToRightMappings, + Function> leftConstantMapping, + Function> rightConstantMapping, Metadata metadata, Session session) { @@ -141,9 +154,9 @@ public boolean isRefinedPartitioningOver( // Refined-over relation is reflexive. public boolean isRefinedPartitioningOver( Partitioning right, - Function> leftToRightMappings, - Function> leftConstantMapping, - Function> rightConstantMapping, + Function> leftToRightMappings, + Function> leftConstantMapping, + Function> rightConstantMapping, Metadata metadata, Session session) { @@ -167,23 +180,23 @@ public boolean isRefinedPartitioningOver( private static boolean isPartitionedWith( ArgumentBinding leftArgument, - Function> leftConstantMapping, + Function> leftConstantMapping, ArgumentBinding rightArgument, - Function> rightConstantMapping, - Function> leftToRightMappings) + Function> rightConstantMapping, + Function> leftToRightMappings) { if (leftArgument.isVariable()) { if (rightArgument.isVariable()) { // variable == variable - Set mappedColumns = leftToRightMappings.apply(leftArgument.getColumn()); - return mappedColumns.contains(rightArgument.getColumn()); + Set mappedColumns = leftToRightMappings.apply(leftArgument.getVariableReference()); + return mappedColumns.contains(rightArgument.getVariableReference()); } else { // variable == constant // Normally, this would be a false condition, but if we happen to have an external // mapping from the symbol to a constant value and that constant value matches the // right value, then we are co-partitioned. - Optional leftConstant = leftConstantMapping.apply(leftArgument.getColumn()); + Optional leftConstant = leftConstantMapping.apply(leftArgument.getVariableReference()); return leftConstant.isPresent() && leftConstant.get().equals(rightArgument.getConstant()); } } @@ -194,49 +207,49 @@ private static boolean isPartitionedWith( } else { // constant == variable - Optional rightConstant = rightConstantMapping.apply(rightArgument.getColumn()); + Optional rightConstant = rightConstantMapping.apply(rightArgument.getVariableReference()); return rightConstant.isPresent() && rightConstant.get().equals(leftArgument.getConstant()); } } } - public boolean isPartitionedOn(Collection columns, Set knownConstants) + public boolean isPartitionedOn(Collection columns, Set knownConstants) { // partitioned on (k_1, k_2, ..., k_n) => partitioned on (k_1, k_2, ..., k_n, k_n+1, ...) // can safely ignore all constant columns when comparing partition properties return arguments.stream() .filter(ArgumentBinding::isVariable) - .map(ArgumentBinding::getColumn) - .filter(symbol -> !knownConstants.contains(symbol)) + .map(ArgumentBinding::getVariableReference) + .filter(variable -> !knownConstants.contains(variable)) .allMatch(columns::contains); } - public boolean isEffectivelySinglePartition(Set knownConstants) + public boolean isEffectivelySinglePartition(Set knownConstants) { return isPartitionedOn(ImmutableSet.of(), knownConstants); } - public boolean isRepartitionEffective(Collection keys, Set knownConstants) + public boolean isRepartitionEffective(Collection keys, Set knownConstants) { - Set keysWithoutConstants = keys.stream() - .filter(symbol -> !knownConstants.contains(symbol)) + Set keysWithoutConstants = keys.stream() + .filter(variable -> !knownConstants.contains(variable)) .collect(toImmutableSet()); - Set nonConstantArgs = arguments.stream() + Set nonConstantArgs = arguments.stream() .filter(ArgumentBinding::isVariable) - .map(ArgumentBinding::getColumn) - .filter(symbol -> !knownConstants.contains(symbol)) + .map(ArgumentBinding::getVariableReference) + .filter(variable -> !knownConstants.contains(variable)) .collect(toImmutableSet()); return !nonConstantArgs.equals(keysWithoutConstants); } - public Partitioning translate(Function translator) + public Partitioning translate(Function translator) { return new Partitioning(handle, arguments.stream() .map(argument -> argument.translate(translator)) .collect(toImmutableList())); } - public Optional translate(Function> translator, Function> constants) + public Optional translate(Function> translator, Function> constants) { ImmutableList.Builder newArguments = ImmutableList.builder(); for (ArgumentBinding argument : arguments) { @@ -287,67 +300,65 @@ public String toString() @Immutable public static final class ArgumentBinding { - private final Symbol column; - private final NullableValue constant; + private final RowExpression rowExpression; @JsonCreator - public ArgumentBinding( - @JsonProperty("column") Symbol column, - @JsonProperty("constant") NullableValue constant) + public ArgumentBinding(@JsonProperty("rowExpression") RowExpression rowExpression) { - this.column = column; - this.constant = constant; - checkArgument((column == null) != (constant == null), "Either column or constant must be set"); + checkArgument(rowExpression instanceof VariableReferenceExpression || rowExpression instanceof ConstantExpression, "Expect either VariableReferenceExpression or ConstantExpression"); + this.rowExpression = requireNonNull(rowExpression, "rowExpression is null"); } - public static ArgumentBinding columnBinding(Symbol column) + @JsonProperty + public RowExpression getRowExpression() { - return new ArgumentBinding(requireNonNull(column, "column is null"), null); + return rowExpression; } - public static ArgumentBinding constantBinding(NullableValue constant) + public boolean isConstant() { - return new ArgumentBinding(null, requireNonNull(constant, "constant is null")); + return rowExpression instanceof ConstantExpression; } - public boolean isConstant() + public boolean isVariable() { - return constant != null; + return rowExpression instanceof VariableReferenceExpression; } - public boolean isVariable() + public VariableReferenceExpression getVariableReference() { - return column != null; + verify(rowExpression instanceof VariableReferenceExpression, "Expect the rowExpression to be a VariableReferenceExpression"); + return (VariableReferenceExpression) rowExpression; } - @JsonProperty public Symbol getColumn() { - return column; + verify(rowExpression instanceof VariableReferenceExpression, "Expect the rowExpression to be a VariableReferenceExpression"); + return new Symbol(getVariableReference().getName()); } - @JsonProperty - public NullableValue getConstant() + public ConstantExpression getConstant() { - return constant; + verify(rowExpression instanceof ConstantExpression, "Expect the rowExpression to be a ConstantExpression"); + return (ConstantExpression) rowExpression; } - public ArgumentBinding translate(Function translator) + public ArgumentBinding translate(Function translator) { if (isConstant()) { return this; } - return columnBinding(translator.apply(column)); + return new ArgumentBinding(translator.apply((VariableReferenceExpression) rowExpression)); } - public Optional translate(Function> translator, Function> constants) + public Optional translate(Function> translator, Function> constants) { if (isConstant()) { return Optional.of(this); } - Optional newColumn = translator.apply(column) - .map(ArgumentBinding::columnBinding); + Optional newColumn = translator.apply((VariableReferenceExpression) rowExpression) + .map(ArgumentBinding::new); if (newColumn.isPresent()) { return newColumn; } @@ -355,17 +366,17 @@ public Optional translate(Function> tr // As a last resort, check for a constant mapping for the symbol // Note: this MUST be last because we want to favor the symbol representation // as it makes further optimizations possible. - return constants.apply(column) - .map(ArgumentBinding::constantBinding); + return constants.apply((VariableReferenceExpression) rowExpression) + .map(ArgumentBinding::new); } @Override public String toString() { - if (constant != null) { - return constant.toString(); + if (rowExpression instanceof ConstantExpression) { + return rowExpression.toString(); } - return "\"" + column + "\""; + return "\"" + rowExpression.toString() + "\""; } @Override @@ -378,14 +389,13 @@ public boolean equals(Object o) return false; } ArgumentBinding that = (ArgumentBinding) o; - return Objects.equals(column, that.column) && - Objects.equals(constant, that.constant); + return Objects.equals(rowExpression, that.rowExpression); } @Override public int hashCode() { - return Objects.hash(column, constant); + return Objects.hash(rowExpression); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PartitioningScheme.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PartitioningScheme.java index 570ce889baceb..e2091170a97e5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PartitioningScheme.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PartitioningScheme.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -30,12 +31,12 @@ public class PartitioningScheme { private final Partitioning partitioning; - private final List outputLayout; - private final Optional hashColumn; + private final List outputLayout; + private final Optional hashColumn; private final boolean replicateNullsAndAny; private final Optional bucketToPartition; - public PartitioningScheme(Partitioning partitioning, List outputLayout) + public PartitioningScheme(Partitioning partitioning, List outputLayout) { this( partitioning, @@ -45,7 +46,7 @@ public PartitioningScheme(Partitioning partitioning, List outputLayout) Optional.empty()); } - public PartitioningScheme(Partitioning partitioning, List outputLayout, Optional hashColumn) + public PartitioningScheme(Partitioning partitioning, List outputLayout, Optional hashColumn) { this( partitioning, @@ -58,15 +59,15 @@ public PartitioningScheme(Partitioning partitioning, List outputLayout, @JsonCreator public PartitioningScheme( @JsonProperty("partitioning") Partitioning partitioning, - @JsonProperty("outputLayout") List outputLayout, - @JsonProperty("hashColumn") Optional hashColumn, + @JsonProperty("outputLayout") List outputLayout, + @JsonProperty("hashColumn") Optional hashColumn, @JsonProperty("replicateNullsAndAny") boolean replicateNullsAndAny, @JsonProperty("bucketToPartition") Optional bucketToPartition) { this.partitioning = requireNonNull(partitioning, "partitioning is null"); this.outputLayout = ImmutableList.copyOf(requireNonNull(outputLayout, "outputLayout is null")); - Set columns = partitioning.getColumns(); + Set columns = partitioning.getVariableReferences(); checkArgument(ImmutableSet.copyOf(outputLayout).containsAll(columns), "Output layout (%s) don't include all partition columns (%s)", outputLayout, columns); @@ -86,13 +87,13 @@ public Partitioning getPartitioning() } @JsonProperty - public List getOutputLayout() + public List getOutputLayout() { return outputLayout; } @JsonProperty - public Optional getHashColumn() + public Optional getHashColumn() { return hashColumn; } @@ -114,15 +115,15 @@ public PartitioningScheme withBucketToPartition(Optional bucketToPartitio return new PartitioningScheme(partitioning, outputLayout, hashColumn, replicateNullsAndAny, bucketToPartition); } - public PartitioningScheme translateOutputLayout(List newOutputLayout) + public PartitioningScheme translateOutputLayout(List newOutputLayout) { requireNonNull(newOutputLayout, "newOutputLayout is null"); checkArgument(newOutputLayout.size() == outputLayout.size()); - Partitioning newPartitioning = partitioning.translate(symbol -> newOutputLayout.get(outputLayout.indexOf(symbol))); + Partitioning newPartitioning = partitioning.translate(variable -> newOutputLayout.get(outputLayout.indexOf(variable))); - Optional newHashSymbol = hashColumn + Optional newHashSymbol = hashColumn .map(outputLayout::indexOf) .map(newOutputLayout::get); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java index c52cb3111fb7f..208e7d5391f02 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java @@ -14,11 +14,13 @@ package com.facebook.presto.sql.planner; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableMap; import java.util.List; @@ -45,7 +47,7 @@ public PlanBuilder(TranslationMap translations, PlanNode root, List public TranslationMap copyTranslations() { - TranslationMap translations = new TranslationMap(getRelationPlan(), getAnalysis(), getTranslations().getLambdaDeclarationToSymbolMap()); + TranslationMap translations = new TranslationMap(getRelationPlan(), getAnalysis(), getTranslations().getLambdaDeclarationToVariableMap()); translations.copyMappingsFrom(getTranslations()); return translations; } @@ -75,7 +77,12 @@ public boolean canTranslate(Expression expression) return translations.containsSymbol(expression); } - public Symbol translate(Expression expression) + public VariableReferenceExpression translate(Expression expression) + { + return translations.get(expression); + } + + public VariableReferenceExpression translateToVariable(Expression expression) { return translations.get(expression); } @@ -97,18 +104,18 @@ public PlanBuilder appendProjections(Iterable expressions, SymbolAll Assignments.Builder projections = Assignments.builder(); // add an identity projection for underlying plan - for (Symbol symbol : getRoot().getOutputSymbols()) { - projections.put(symbol, symbol.toSymbolReference()); + for (VariableReferenceExpression variable : getRoot().getOutputVariables()) { + projections.put(variable, new SymbolReference(variable.getName())); } - ImmutableMap.Builder newTranslations = ImmutableMap.builder(); + ImmutableMap.Builder newTranslations = ImmutableMap.builder(); for (Expression expression : expressions) { - Symbol symbol = symbolAllocator.newSymbol(expression, getAnalysis().getTypeWithCoercions(expression)); - projections.put(symbol, translations.rewrite(expression)); - newTranslations.put(symbol, expression); + VariableReferenceExpression variable = symbolAllocator.newVariable(expression, getAnalysis().getTypeWithCoercions(expression)); + projections.put(variable, translations.rewrite(expression)); + newTranslations.put(variable, expression); } // Now append the new translations into the TranslationMap - for (Map.Entry entry : newTranslations.build().entrySet()) { + for (Map.Entry entry : newTranslations.build().entrySet()) { translations.put(entry.getValue(), entry.getKey()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragment.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragment.java index e14285c958425..172af1f54f578 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragment.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragment.java @@ -16,6 +16,7 @@ import com.facebook.presto.cost.StatsAndCosts; import com.facebook.presto.operator.StageExecutionDescriptor; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -29,7 +30,6 @@ import javax.annotation.concurrent.Immutable; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.Set; @@ -43,7 +43,7 @@ public class PlanFragment { private final PlanFragmentId id; private final PlanNode root; - private final Map symbols; + private final Set variables; private final PartitioningHandle partitioning; private final List tableScanSchedulingOrder; private final List types; @@ -58,7 +58,7 @@ public class PlanFragment public PlanFragment( @JsonProperty("id") PlanFragmentId id, @JsonProperty("root") PlanNode root, - @JsonProperty("symbols") Map symbols, + @JsonProperty("variables") Set variables, @JsonProperty("partitioning") PartitioningHandle partitioning, @JsonProperty("tableScanSchedulingOrder") List tableScanSchedulingOrder, @JsonProperty("partitioningScheme") PartitioningScheme partitioningScheme, @@ -69,7 +69,7 @@ public PlanFragment( { this.id = requireNonNull(id, "id is null"); this.root = requireNonNull(root, "root is null"); - this.symbols = requireNonNull(symbols, "symbols is null"); + this.variables = requireNonNull(variables, "variables is null"); this.partitioning = requireNonNull(partitioning, "partitioning is null"); this.tableScanSchedulingOrder = ImmutableList.copyOf(requireNonNull(tableScanSchedulingOrder, "tableScanSchedulingOrder is null")); this.stageExecutionDescriptor = requireNonNull(stageExecutionDescriptor, "stageExecutionDescriptor is null"); @@ -77,11 +77,11 @@ public PlanFragment( this.statsAndCosts = requireNonNull(statsAndCosts, "statsAndCosts is null"); this.jsonRepresentation = requireNonNull(jsonRepresentation, "jsonRepresentation is null"); - checkArgument(ImmutableSet.copyOf(root.getOutputSymbols()).containsAll(partitioningScheme.getOutputLayout()), - "Root node outputs (%s) does not include all fragment outputs (%s)", root.getOutputSymbols(), partitioningScheme.getOutputLayout()); + checkArgument(root.getOutputVariables().containsAll(partitioningScheme.getOutputLayout()), + "Root node outputs (%s) does not include all fragment outputs (%s)", root.getOutputVariables(), partitioningScheme.getOutputLayout()); types = partitioningScheme.getOutputLayout().stream() - .map(symbols::get) + .map(VariableReferenceExpression::getType) .collect(toImmutableList()); ImmutableList.Builder remoteSourceNodes = ImmutableList.builder(); @@ -104,9 +104,9 @@ public PlanNode getRoot() } @JsonProperty - public Map getSymbols() + public Set getVariables() { - return symbols; + return variables; } @JsonProperty @@ -202,7 +202,7 @@ public PlanFragment withBucketToPartition(Optional bucketToPartition) return new PlanFragment( id, root, - symbols, + variables, partitioning, tableScanSchedulingOrder, partitioningScheme.withBucketToPartition(bucketToPartition), @@ -217,7 +217,7 @@ public PlanFragment withFixedLifespanScheduleGroupedExecution(List c return new PlanFragment( id, root, - symbols, + variables, partitioning, tableScanSchedulingOrder, partitioningScheme, @@ -232,7 +232,7 @@ public PlanFragment withDynamicLifespanScheduleGroupedExecution(List return new PlanFragment( id, root, - symbols, + variables, partitioning, tableScanSchedulingOrder, partitioningScheme, @@ -247,7 +247,7 @@ public PlanFragment withRecoverableGroupedExecution(List capableTabl return new PlanFragment( id, root, - symbols, + variables, partitioning, tableScanSchedulingOrder, partitioningScheme, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java index e39c853be69af..bbfb91e6c7e73 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java @@ -39,9 +39,9 @@ import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.spi.predicate.NullableValue; import com.facebook.presto.spi.predicate.TupleDomain; -import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -97,7 +97,7 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; import static com.facebook.presto.sql.planner.SchedulingOrderVisitor.scheduleOrder; -import static com.facebook.presto.sql.planner.SymbolsExtractor.extractOutputSymbols; +import static com.facebook.presto.sql.planner.SymbolsExtractor.extractOutputVariables; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; @@ -110,14 +110,12 @@ import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.jsonFragmentPlan; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Predicates.in; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.concat; import static com.google.common.collect.Iterables.getOnlyElement; -import static com.google.common.collect.Maps.filterKeys; import static com.google.common.collect.Streams.stream; import static com.google.common.graph.Traverser.forTree; import static java.lang.String.format; @@ -148,6 +146,7 @@ public PlanFragmenter(Metadata metadata, NodePartitioningManager nodePartitionin public SubPlan createSubPlans(Session session, Plan plan, boolean forceSingleNode, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) { + SymbolAllocator symbolAllocator = new SymbolAllocator(plan.getTypes().allTypes()); Fragmenter fragmenter = new Fragmenter( session, metadata, @@ -159,7 +158,9 @@ public SubPlan createSubPlans(Session session, Plan plan, boolean forceSingleNod new SymbolAllocator(plan.getTypes().allTypes()), getTableWriterNodeIds(plan.getRoot())); - FragmentProperties properties = new FragmentProperties(new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), plan.getRoot().getOutputSymbols())); + FragmentProperties properties = new FragmentProperties(new PartitioningScheme( + Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), + plan.getRoot().getOutputVariables())); if (forceSingleNode || isForceSingleNodeOutput(session)) { properties = properties.setSingleNodeDistribution(); } @@ -273,7 +274,7 @@ private SubPlan reassignPartitioningHandleIfNecessaryHelper(Session session, Sub PlanFragment newFragment = new PlanFragment( fragment.getId(), newRoot, - fragment.getSymbols(), + fragment.getVariables(), fragment.getPartitioning(), fragment.getTableScanSchedulingOrder(), new PartitioningScheme( @@ -361,8 +362,8 @@ private SubPlan buildFragment(PlanNode root, FragmentProperties properties, Plan schedulingOrder, properties.getPartitionedSources()); - Map fragmentSymbolTypes = filterKeys(symbolAllocator.getTypes().allTypes(), in(extractOutputSymbols(root))); - planSanityChecker.validatePlanFragment(root, session, metadata, sqlParser, TypeProvider.viewOf(fragmentSymbolTypes), warningCollector); + Set fragmentVariableTypes = extractOutputVariables(root); + planSanityChecker.validatePlanFragment(root, session, metadata, sqlParser, TypeProvider.fromVariables(fragmentVariableTypes), warningCollector); Set tableWriterNodeIds = getTableWriterNodeIds(root); boolean outputTableWriterFragment = tableWriterNodeIds.stream().anyMatch(outputTableWriterNodeIds::contains); @@ -377,14 +378,14 @@ private SubPlan buildFragment(PlanNode root, FragmentProperties properties, Plan PlanFragment fragment = new PlanFragment( fragmentId, root, - fragmentSymbolTypes, + fragmentVariableTypes, properties.getPartitioningHandle(), schedulingOrder, properties.getPartitioningScheme(), StageExecutionDescriptor.ungroupedExecution(), outputTableWriterFragment, statsAndCosts.getForSubplan(root), - Optional.of(jsonFragmentPlan(root, fragmentSymbolTypes, metadata.getFunctionManager(), session))); + Optional.of(jsonFragmentPlan(root, fragmentVariableTypes, metadata.getFunctionManager(), session))); return new SubPlan(fragment, properties.getChildren()); } @@ -496,7 +497,7 @@ else if (exchange.getType() == ExchangeNode.Type.REPARTITION) { .map(PlanFragment::getId) .collect(toImmutableList()); - return new RemoteSourceNode(exchange.getId(), childrenIds, exchange.getOutputSymbols(), exchange.getOrderingScheme(), exchange.getType()); + return new RemoteSourceNode(exchange.getId(), childrenIds, exchange.getOutputVariables(), exchange.getOrderingScheme(), exchange.getType()); } private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, RewriteContext context) @@ -512,24 +513,24 @@ private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, Rewrite .orElseThrow(() -> new IllegalArgumentException("Unsupported partitioning handle: " + partitioningHandle)); Partitioning partitioning = partitioningScheme.getPartitioning(); - PartitioningSymbolAssignments partitioningSymbolAssignments = assignPartitioningSymbols(partitioning); - Map symbolToColumnMap = assignTemporaryTableColumnNames(exchange.getOutputSymbols(), partitioningSymbolAssignments.getConstants().keySet()); - List partitioningSymbols = partitioningSymbolAssignments.getSymbols(); - List partitionColumns = partitioningSymbols.stream() - .map(symbol -> symbolToColumnMap.get(symbol).getName()) + PartitioningVariableAssignments partitioningVariableAssignments = assignPartitioningVariables(partitioning); + Map variableToColumnMap = assignTemporaryTableColumnNames(exchange.getOutputVariables(), partitioningVariableAssignments.getConstants().keySet()); + List partitioningVariables = partitioningVariableAssignments.getVariables(); + List partitionColumns = partitioningVariables.stream() + .map(variable -> variableToColumnMap.get(variable).getName()) .collect(toImmutableList()); PartitioningMetadata partitioningMetadata = new PartitioningMetadata(partitioningHandle, partitionColumns); TableHandle temporaryTableHandle = metadata.createTemporaryTable( session, connectorId.getCatalogName(), - ImmutableList.copyOf(symbolToColumnMap.values()), + ImmutableList.copyOf(variableToColumnMap.values()), Optional.of(partitioningMetadata)); TableScanNode scan = createTemporaryTableScan( temporaryTableHandle, - exchange.getOutputSymbols(), - symbolToColumnMap, + exchange.getOutputVariables(), + variableToColumnMap, partitioningMetadata); checkArgument( @@ -537,14 +538,16 @@ private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, Rewrite "materialized remote exchange is not supported when replicateNullsAndAny is needed"); TableFinishNode write = createTemporaryTableWrite( temporaryTableHandle, - symbolToColumnMap, - exchange.getOutputSymbols(), + variableToColumnMap, + exchange.getOutputVariables(), exchange.getInputs(), exchange.getSources(), - partitioningSymbolAssignments.getConstants(), + partitioningVariableAssignments.getConstants(), partitioningMetadata); - FragmentProperties writeProperties = new FragmentProperties(new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), write.getOutputSymbols())); + FragmentProperties writeProperties = new FragmentProperties(new PartitioningScheme( + Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), + write.getOutputVariables())); writeProperties.setCoordinatorOnlyDistribution(); List children = ImmutableList.of(buildSubPlan(write, writeProperties, context)); @@ -553,33 +556,33 @@ private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, Rewrite return visitTableScan(scan, context); } - private PartitioningSymbolAssignments assignPartitioningSymbols(Partitioning partitioning) + private PartitioningVariableAssignments assignPartitioningVariables(Partitioning partitioning) { - ImmutableList.Builder symbols = ImmutableList.builder(); - ImmutableMap.Builder constants = ImmutableMap.builder(); + ImmutableList.Builder variables = ImmutableList.builder(); + ImmutableMap.Builder constants = ImmutableMap.builder(); for (ArgumentBinding argumentBinding : partitioning.getArguments()) { - Symbol symbol; + VariableReferenceExpression variable; if (argumentBinding.isConstant()) { - NullableValue constant = argumentBinding.getConstant(); + ConstantExpression constant = argumentBinding.getConstant(); Expression expression = literalEncoder.toExpression(constant.getValue(), constant.getType()); - symbol = symbolAllocator.newSymbol(expression, constant.getType()); - constants.put(symbol, expression); + variable = symbolAllocator.newVariable(expression, constant.getType()); + constants.put(variable, expression); } else { - symbol = argumentBinding.getColumn(); + variable = argumentBinding.getVariableReference(); } - symbols.add(symbol); + variables.add(variable); } - return new PartitioningSymbolAssignments(symbols.build(), constants.build()); + return new PartitioningVariableAssignments(variables.build(), constants.build()); } - private Map assignTemporaryTableColumnNames(Collection outputSymbols, Collection constantPartitioningSymbols) + private Map assignTemporaryTableColumnNames(Collection outputVariables, Collection constantPartitioningVariables) { - ImmutableMap.Builder result = ImmutableMap.builder(); + ImmutableMap.Builder result = ImmutableMap.builder(); int column = 0; - for (Symbol outputSymbol : concat(outputSymbols, constantPartitioningSymbols)) { - String columnName = format("_c%d_%s", column, outputSymbol.getName()); - result.put(outputSymbol, new ColumnMetadata(columnName, symbolAllocator.getTypes().get(outputSymbol))); + for (VariableReferenceExpression outputVariable : concat(outputVariables, constantPartitioningVariables)) { + String columnName = format("_c%d_%s", column, outputVariable.getName()); + result.put(outputVariable, new ColumnMetadata(columnName, outputVariable.getType())); column++; } return result.build(); @@ -587,13 +590,13 @@ private Map assignTemporaryTableColumnNames(Collection outputSymbols, - Map symbolToColumnMap, + List outputVariables, + Map variableToColumnMap, PartitioningMetadata expectedPartitioningMetadata) { Map columnHandles = metadata.getColumnHandles(session, tableHandle); - Map outputColumns = outputSymbols.stream() - .collect(toImmutableMap(identity(), symbolToColumnMap::get)); + Map outputColumns = outputVariables.stream() + .collect(toImmutableMap(identity(), variableToColumnMap::get)); Set outputColumnHandles = outputColumns.values().stream() .map(ColumnMetadata::getName) .map(columnHandles::get) @@ -609,13 +612,13 @@ private TableScanNode createTemporaryTableScan( .collect(toImmutableList())); verify(selectedLayout.getLayout().getTablePartitioning().equals(Optional.of(expectedPartitioning)), "invalid temporary table partitioning"); - Map assignments = outputSymbols.stream() - .collect(toImmutableMap(identity(), symbol -> columnHandles.get(outputColumns.get(symbol).getName()))); + Map assignments = outputVariables.stream() + .collect(toImmutableMap(identity(), variable -> columnHandles.get(outputColumns.get(variable).getName()))); return new TableScanNode( idAllocator.getNextId(), selectedLayout.getLayout().getNewTableHandle(), - outputSymbols, + outputVariables, assignments, TupleDomain.all(), TupleDomain.all(), @@ -624,27 +627,27 @@ private TableScanNode createTemporaryTableScan( private TableFinishNode createTemporaryTableWrite( TableHandle tableHandle, - Map symbolToColumnMap, - List outputs, - List> inputs, + Map variableToColumnMap, + List outputs, + List> inputs, List sources, - Map constantExpressions, + Map constantExpressions, PartitioningMetadata partitioningMetadata) { if (!constantExpressions.isEmpty()) { - List constantSymbols = ImmutableList.copyOf(constantExpressions.keySet()); + List constantVariables = ImmutableList.copyOf(constantExpressions.keySet()); // update outputs - outputs = ImmutableList.builder() + outputs = ImmutableList.builder() .addAll(outputs) - .addAll(constantSymbols) + .addAll(constantVariables) .build(); // update inputs inputs = inputs.stream() - .map(input -> ImmutableList.builder() + .map(input -> ImmutableList.builder() .addAll(input) - .addAll(constantSymbols) + .addAll(constantVariables) .build()) .collect(toImmutableList()); @@ -652,8 +655,8 @@ private TableFinishNode createTemporaryTableWrite( sources = sources.stream() .map(source -> { Assignments.Builder assignments = Assignments.builder(); - assignments.putIdentities(source.getOutputSymbols()); - constantSymbols.forEach(symbol -> assignments.put(symbol, constantExpressions.get(symbol))); + assignments.putIdentities(source.getOutputVariables()); + constantVariables.forEach(variable -> assignments.put(variable, constantExpressions.get(variable))); return new ProjectNode(idAllocator.getNextId(), source, assignments.build()); }) .collect(toImmutableList()); @@ -668,15 +671,15 @@ private TableFinishNode createTemporaryTableWrite( ConnectorNewTableLayout expectedNewTableLayout = new ConnectorNewTableLayout(partitioningHandle.getConnectorHandle(), partitionColumns); verify(insertLayout.getLayout().equals(expectedNewTableLayout), "unexpected new table layout"); - Map columnNameToSymbol = symbolToColumnMap.entrySet().stream() + Map columnNameToVariable = variableToColumnMap.entrySet().stream() .collect(toImmutableMap(entry -> entry.getValue().getName(), Map.Entry::getKey)); - List partitioningSymbols = partitionColumns.stream() - .map(columnNameToSymbol::get) + List partitioningVariables = partitionColumns.stream() + .map(columnNameToVariable::get) .collect(toImmutableList()); InsertTableHandle insertTableHandle = metadata.beginInsert(session, tableHandle); List outputColumnNames = outputs.stream() - .map(symbolToColumnMap::get) + .map(variableToColumnMap::get) .map(ColumnMetadata::getName) .collect(toImmutableList()); @@ -701,7 +704,7 @@ private TableFinishNode createTemporaryTableWrite( REPARTITION, REMOTE_STREAMING, new PartitioningScheme( - Partitioning.create(partitioningHandle, partitioningSymbols), + Partitioning.create(partitioningHandle, partitioningVariables), outputs, Optional.empty(), false, @@ -710,13 +713,13 @@ private TableFinishNode createTemporaryTableWrite( inputs, Optional.empty())), insertHandle, - symbolAllocator.newSymbol("partialrows", BIGINT), - symbolAllocator.newSymbol("fragment", VARBINARY), - symbolAllocator.newSymbol("tablecommitcontext", VARBINARY), + symbolAllocator.newVariable("partialrows", BIGINT), + symbolAllocator.newVariable("fragment", VARBINARY), + symbolAllocator.newVariable("tablecommitcontext", VARBINARY), outputs, outputColumnNames, Optional.of(new PartitioningScheme( - Partitioning.create(partitioningHandle, partitioningSymbols), + Partitioning.create(partitioningHandle, partitioningVariables), outputs, Optional.empty(), false, @@ -724,7 +727,7 @@ private TableFinishNode createTemporaryTableWrite( Optional.empty(), Optional.empty()))), insertHandle, - symbolAllocator.newSymbol("rows", BIGINT), + symbolAllocator.newVariable("rows", BIGINT), Optional.empty(), Optional.empty()); } @@ -1200,7 +1203,7 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext context) return new TableScanNode( node.getId(), newTableHandle, - node.getOutputSymbols(), + node.getOutputVariables(), node.getAssignments(), node.getCurrentConstraint(), node.getEnforcedConstraint(), @@ -1208,26 +1211,26 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext context) } } - private static class PartitioningSymbolAssignments + private static class PartitioningVariableAssignments { - private final List symbols; - private final Map constants; + private final List variables; + private final Map constants; - private PartitioningSymbolAssignments(List symbols, Map constants) + private PartitioningVariableAssignments(List variables, Map constants) { - this.symbols = ImmutableList.copyOf(requireNonNull(symbols, "symbols is null")); + this.variables = ImmutableList.copyOf(requireNonNull(variables, "variables is null")); this.constants = ImmutableMap.copyOf(requireNonNull(constants, "constants is null")); checkArgument( - ImmutableSet.copyOf(symbols).containsAll(constants.keySet()), - "partitioningSymbols list must contain all partitioning symbols including constants"); + ImmutableSet.copyOf(variables).containsAll(constants.keySet()), + "partitioningVariables list must contain all partitioning variables including constants"); } - public List getSymbols() + public List getVariables() { - return symbols; + return variables; } - public Map getConstants() + public Map getConstants() { return constants; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java index aba526382088a..82d85a9c7a59b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java @@ -14,18 +14,17 @@ package com.facebook.presto.sql.planner; import com.facebook.presto.spi.block.SortOrder; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.OrderBy; import com.facebook.presto.sql.tree.SortItem; import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.function.Function; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.ImmutableList.toImmutableList; public class PlannerUtils { @@ -45,28 +44,18 @@ public static SortOrder toSortOrder(SortItem sortItem) return SortOrder.DESC_NULLS_LAST; } - public static OrderingScheme toOrderingScheme(List sortItems) - { - return toOrderingScheme(sortItems, item -> { - checkArgument(item instanceof SymbolReference, "must be symbol reference"); - return new Symbol(((SymbolReference) item).getName()); - }); - } - - public static OrderingScheme toOrderingScheme(List sortItems, Function translator) + public static OrderingScheme toOrderingScheme(List sortItems, TypeProvider typeProvider) { // The logic is similar to QueryPlanner::sort - Map orderings = new LinkedHashMap<>(); + Map orderings = new LinkedHashMap<>(); for (SortItem item : sortItems) { - Symbol symbol = translator.apply(item.getSortKey()); + Expression sortKey = item.getSortKey(); + checkArgument(sortKey instanceof SymbolReference, "must be symbol reference"); + Symbol symbol = Symbol.from(sortKey); + VariableReferenceExpression variable = new VariableReferenceExpression(symbol.getName(), typeProvider.get(symbol)); // don't override existing keys, i.e. when "ORDER BY a ASC, a DESC" is specified - orderings.putIfAbsent(symbol, toSortOrder(item)); + orderings.putIfAbsent(variable, toSortOrder(item)); } - return new OrderingScheme(orderings.keySet().stream().collect(toImmutableList()), orderings); - } - - public static OrderingScheme toOrderingScheme(OrderBy orderBy) - { - return toOrderingScheme(orderBy.getSortItems()); + return new OrderingScheme(ImmutableList.copyOf(orderings.keySet()), orderings); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index c4ab08b92fd77..b4f5923f4f8ff 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -19,6 +19,7 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.analyzer.Field; @@ -79,6 +80,7 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy; +import static com.facebook.presto.sql.planner.PlannerUtils.toOrderingScheme; import static com.facebook.presto.sql.planner.PlannerUtils.toSortOrder; import static com.facebook.presto.sql.planner.optimizations.WindowNodeUtil.toBoundType; import static com.facebook.presto.sql.planner.optimizations.WindowNodeUtil.toWindowType; @@ -98,7 +100,7 @@ class QueryPlanner private final Analysis analysis; private final SymbolAllocator symbolAllocator; private final PlanNodeIdAllocator idAllocator; - private final Map, Symbol> lambdaDeclarationToSymbolMap; + private final Map, VariableReferenceExpression> lambdaDeclarationToVariableMap; private final Metadata metadata; private final Session session; private final SubqueryPlanner subqueryPlanner; @@ -107,24 +109,24 @@ class QueryPlanner Analysis analysis, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, - Map, Symbol> lambdaDeclarationToSymbolMap, + Map, VariableReferenceExpression> lambdaDeclarationToVariableMap, Metadata metadata, Session session) { requireNonNull(analysis, "analysis is null"); requireNonNull(symbolAllocator, "symbolAllocator is null"); requireNonNull(idAllocator, "idAllocator is null"); - requireNonNull(lambdaDeclarationToSymbolMap, "lambdaDeclarationToSymbolMap is null"); + requireNonNull(lambdaDeclarationToVariableMap, "lambdaDeclarationToVariableMap is null"); requireNonNull(metadata, "metadata is null"); requireNonNull(session, "session is null"); this.analysis = analysis; this.symbolAllocator = symbolAllocator; this.idAllocator = idAllocator; - this.lambdaDeclarationToSymbolMap = lambdaDeclarationToSymbolMap; + this.lambdaDeclarationToVariableMap = lambdaDeclarationToVariableMap; this.metadata = metadata; this.session = session; - this.subqueryPlanner = new SubqueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, session); + this.subqueryPlanner = new SubqueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session); } public RelationPlan plan(Query query) @@ -141,10 +143,7 @@ public RelationPlan plan(Query query) builder = project(builder, analysis.getOutputExpressions(query)); builder = limit(builder, query); - return new RelationPlan( - builder.getRoot(), - analysis.getScope(query), - computeOutputs(builder, analysis.getOutputExpressions(query))); + return new RelationPlan(builder.getRoot(), analysis.getScope(query), computeOutputs(builder, analysis.getOutputExpressions(query))); } public RelationPlan plan(QuerySpecification node) @@ -191,10 +190,7 @@ public RelationPlan plan(QuerySpecification node) builder = project(builder, outputs); builder = limit(builder, node); - return new RelationPlan( - builder.getRoot(), - analysis.getScope(node), - computeOutputs(builder, outputs)); + return new RelationPlan(builder.getRoot(), analysis.getScope(node), computeOutputs(builder, outputs)); } public DeleteNode plan(Delete node) @@ -205,29 +201,31 @@ public DeleteNode plan(Delete node) Type rowIdType = metadata.getColumnMetadata(session, handle, rowIdHandle).getType(); // add table columns - ImmutableList.Builder outputSymbols = ImmutableList.builder(); - ImmutableMap.Builder columns = ImmutableMap.builder(); + ImmutableList.Builder outputVariablesBuilder = ImmutableList.builder(); + ImmutableMap.Builder columns = ImmutableMap.builder(); ImmutableList.Builder fields = ImmutableList.builder(); for (Field field : descriptor.getAllFields()) { - Symbol symbol = symbolAllocator.newSymbol(field.getName().get(), field.getType()); - outputSymbols.add(symbol); - columns.put(symbol, analysis.getColumn(field)); + VariableReferenceExpression variable = symbolAllocator.newVariable(field.getName().get(), field.getType()); + outputVariablesBuilder.add(variable); + columns.put(variable, analysis.getColumn(field)); fields.add(field); } // add rowId column Field rowIdField = Field.newUnqualified(Optional.empty(), rowIdType); - Symbol rowIdSymbol = symbolAllocator.newSymbol("$rowId", rowIdField.getType()); - outputSymbols.add(rowIdSymbol); - columns.put(rowIdSymbol, rowIdHandle); + VariableReferenceExpression rowIdVariable = symbolAllocator.newVariable("$rowId", rowIdField.getType()); + outputVariablesBuilder.add(rowIdVariable); + columns.put(rowIdVariable, rowIdHandle); fields.add(rowIdField); // create table scan - PlanNode tableScan = new TableScanNode(idAllocator.getNextId(), handle, outputSymbols.build(), columns.build()); + List outputVariables = outputVariablesBuilder.build(); + List outputSymbols = outputVariables.stream().map(VariableReferenceExpression::getName).map(Symbol::new).collect(toImmutableList()); + PlanNode tableScan = new TableScanNode(idAllocator.getNextId(), handle, outputVariables, columns.build()); Scope scope = Scope.builder().withRelationType(RelationId.anonymous(), new RelationType(fields.build())).build(); - RelationPlan relationPlan = new RelationPlan(tableScan, scope, outputSymbols.build()); + RelationPlan relationPlan = new RelationPlan(tableScan, scope, outputVariables); - TranslationMap translations = new TranslationMap(relationPlan, analysis, lambdaDeclarationToSymbolMap); + TranslationMap translations = new TranslationMap(relationPlan, analysis, lambdaDeclarationToVariableMap); translations.setFieldMappings(relationPlan.getFieldMappings()); PlanBuilder builder = new PlanBuilder(translations, relationPlan.getRoot(), analysis.getParameters()); @@ -237,26 +235,27 @@ public DeleteNode plan(Delete node) } // create delete node - Symbol rowId = builder.translate(new FieldReference(relationPlan.getDescriptor().indexOf(rowIdField))); - List outputs = ImmutableList.of( - symbolAllocator.newSymbol("partialrows", BIGINT), - symbolAllocator.newSymbol("fragment", VARBINARY)); + VariableReferenceExpression rowId = new VariableReferenceExpression(builder.translate(new FieldReference(relationPlan.getDescriptor().indexOf(rowIdField))).getName(), rowIdField.getType()); + List deleteNodeOutputVariables = ImmutableList.of( + symbolAllocator.newVariable("partialrows", BIGINT), + symbolAllocator.newVariable("fragment", VARBINARY)); - return new DeleteNode(idAllocator.getNextId(), builder.getRoot(), new DeleteHandle(handle, metadata.getTableMetadata(session, handle).getTable()), rowId, outputs); + List deleteNodeOutputSymbols = deleteNodeOutputVariables.stream().map(variable -> new Symbol(variable.getName())).collect(toImmutableList()); + return new DeleteNode(idAllocator.getNextId(), builder.getRoot(), new DeleteHandle(handle, metadata.getTableMetadata(session, handle).getTable()), rowId, deleteNodeOutputVariables); } - private static List computeOutputs(PlanBuilder builder, List outputExpressions) + private static List computeOutputs(PlanBuilder builder, List outputExpressions) { - ImmutableList.Builder outputSymbols = ImmutableList.builder(); + ImmutableList.Builder outputs = ImmutableList.builder(); for (Expression expression : outputExpressions) { - outputSymbols.add(builder.translate(expression)); + outputs.add(builder.translate(expression)); } - return outputSymbols.build(); + return outputs.build(); } private PlanBuilder planQueryBody(Query query) { - RelationPlan relationPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, session) + RelationPlan relationPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session) .process(query.getQueryBody(), null); return planBuilderFor(relationPlan); @@ -267,7 +266,7 @@ private PlanBuilder planFrom(QuerySpecification node) RelationPlan relationPlan; if (node.getFrom().isPresent()) { - relationPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, session) + relationPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session) .process(node.getFrom().get(), null); } else { @@ -279,21 +278,21 @@ private PlanBuilder planFrom(QuerySpecification node) private PlanBuilder planBuilderFor(PlanBuilder builder, Scope scope, Iterable expressionsToRemap) { - Map expressionsToSymbols = symbolsForExpressions(builder, expressionsToRemap); + Map expressionsToVariables = variablesForExpressions(builder, expressionsToRemap); PlanBuilder newBuilder = planBuilderFor(builder, scope); - expressionsToSymbols.entrySet() + expressionsToVariables.entrySet() .forEach(entry -> newBuilder.getTranslations().put(entry.getKey(), entry.getValue())); return newBuilder; } private PlanBuilder planBuilderFor(PlanBuilder builder, Scope scope) { - return planBuilderFor(new RelationPlan(builder.getRoot(), scope, builder.getRoot().getOutputSymbols())); + return planBuilderFor(new RelationPlan(builder.getRoot(), scope, builder.getRoot().getOutputVariables())); } private PlanBuilder planBuilderFor(RelationPlan relationPlan) { - TranslationMap translations = new TranslationMap(relationPlan, analysis, lambdaDeclarationToSymbolMap); + TranslationMap translations = new TranslationMap(relationPlan, analysis, lambdaDeclarationToVariableMap); // Make field->symbol mapping from underlying relation plan available for translations // This makes it possible to rewrite FieldOrExpressions that reference fields from the FROM clause directly @@ -332,20 +331,20 @@ private PlanBuilder project(PlanBuilder subPlan, Iterable expression private PlanBuilder project(PlanBuilder subPlan, Iterable expressions) { - TranslationMap outputTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap); + TranslationMap outputTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToVariableMap); Assignments.Builder projections = Assignments.builder(); for (Expression expression : expressions) { if (expression instanceof SymbolReference) { - Symbol symbol = Symbol.from(expression); - projections.put(symbol, expression); - outputTranslations.put(expression, symbol); + VariableReferenceExpression variable = symbolAllocator.toVariableReference(Symbol.from(expression)); + projections.put(variable, expression); + outputTranslations.put(expression, variable); continue; } - Symbol symbol = symbolAllocator.newSymbol(expression, analysis.getTypeWithCoercions(expression)); - projections.put(symbol, subPlan.rewrite(expression)); - outputTranslations.put(expression, symbol); + VariableReferenceExpression variable = symbolAllocator.newVariable(expression, analysis.getTypeWithCoercions(expression)); + projections.put(variable, subPlan.rewrite(expression)); + outputTranslations.put(expression, variable); } return new PlanBuilder(outputTranslations, new ProjectNode( @@ -355,14 +354,14 @@ private PlanBuilder project(PlanBuilder subPlan, Iterable expression analysis.getParameters()); } - private Map coerce(Iterable expressions, PlanBuilder subPlan, TranslationMap translations) + private Map coerce(Iterable expressions, PlanBuilder subPlan, TranslationMap translations) { - ImmutableMap.Builder projections = ImmutableMap.builder(); + ImmutableMap.Builder projections = ImmutableMap.builder(); for (Expression expression : expressions) { Type type = analysis.getType(expression); Type coercion = analysis.getCoercion(expression); - Symbol symbol = symbolAllocator.newSymbol(expression, firstNonNull(coercion, type)); + VariableReferenceExpression variable = symbolAllocator.newVariable(expression, firstNonNull(coercion, type)); Expression rewritten = subPlan.rewrite(expression); if (coercion != null) { rewritten = new Cast( @@ -371,8 +370,8 @@ private Map coerce(Iterable expression false, metadata.getTypeManager().isTypeOnlyCoercion(type, coercion)); } - projections.put(symbol, rewritten); - translations.put(expression, symbol); + projections.put(variable, rewritten); + translations.put(expression, variable); } return projections.build(); @@ -380,7 +379,7 @@ private Map coerce(Iterable expression private PlanBuilder explicitCoercionFields(PlanBuilder subPlan, Iterable alreadyCoerced, Iterable uncoerced) { - TranslationMap translations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap); + TranslationMap translations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToVariableMap); Assignments.Builder projections = Assignments.builder(); projections.putAll(coerce(uncoerced, subPlan, translations)); @@ -390,14 +389,14 @@ private PlanBuilder explicitCoercionFields(PlanBuilder subPlan, Iterable alreadyCoerced, Iterable uncoerced) + private PlanBuilder explicitCoercionSymbols(PlanBuilder subPlan, List alreadyCoerced, Iterable uncoerced) { TranslationMap translations = subPlan.copyTranslations(); @@ -465,23 +464,23 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) // 2. Aggregate // 2.a. Rewrite aggregate arguments - TranslationMap argumentTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap); + TranslationMap argumentTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToVariableMap); - ImmutableList.Builder aggregationArgumentsBuilder = ImmutableList.builder(); + ImmutableList.Builder aggregationArgumentsBuilder = ImmutableList.builder(); for (Expression argument : arguments.build()) { - Symbol symbol = subPlan.translate(argument); - argumentTranslations.put(argument, symbol); - aggregationArgumentsBuilder.add(symbol); + VariableReferenceExpression variable = subPlan.translate(argument); + argumentTranslations.put(argument, variable); + aggregationArgumentsBuilder.add(variable); } - List aggregationArguments = aggregationArgumentsBuilder.build(); + List aggregationArguments = aggregationArgumentsBuilder.build(); // 2.b. Rewrite grouping columns - TranslationMap groupingTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap); - Map groupingSetMappings = new LinkedHashMap<>(); + TranslationMap groupingTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToVariableMap); + Map groupingSetMappings = new LinkedHashMap<>(); for (Expression expression : groupByExpressions) { - Symbol input = subPlan.translate(expression); - Symbol output = symbolAllocator.newSymbol(expression, analysis.getTypeWithCoercions(expression), "gid"); + VariableReferenceExpression input = subPlan.translate(expression); + VariableReferenceExpression output = symbolAllocator.newVariable(expression, analysis.getTypeWithCoercions(expression), "gid"); groupingTranslations.put(expression, output); groupingSetMappings.put(output, input); } @@ -489,7 +488,7 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) // This tracks the grouping sets before complex expressions are considered (see comments below) // It's also used to compute the descriptors needed to implement grouping() List> columnOnlyGroupingSets = ImmutableList.of(ImmutableSet.of()); - List> groupingSets = ImmutableList.of(ImmutableList.of()); + List> groupingSets = ImmutableList.of(ImmutableList.of()); if (node.getGroupBy().isPresent()) { // For the purpose of "distinct", we need to canonicalize column references that may have varying @@ -509,9 +508,9 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) } // add in the complex expressions an turn materialize the grouping sets in terms of plan columns - ImmutableList.Builder> groupingSetBuilder = ImmutableList.builder(); + ImmutableList.Builder> groupingSetBuilder = ImmutableList.builder(); for (Set groupingSet : columnOnlyGroupingSets) { - ImmutableList.Builder columns = ImmutableList.builder(); + ImmutableList.Builder columns = ImmutableList.builder(); groupingSetAnalysis.getComplexExpressions().stream() .map(groupingTranslations::get) .forEach(columns::add); @@ -527,30 +526,30 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) } // 2.c. Generate GroupIdNode (multiple grouping sets) or ProjectNode (single grouping set) - Optional groupIdSymbol = Optional.empty(); + Optional groupIdVariable = Optional.empty(); if (groupingSets.size() > 1) { - groupIdSymbol = Optional.of(symbolAllocator.newSymbol("groupId", BIGINT)); - GroupIdNode groupId = new GroupIdNode(idAllocator.getNextId(), subPlan.getRoot(), groupingSets, groupingSetMappings, aggregationArguments, groupIdSymbol.get()); + groupIdVariable = Optional.of(symbolAllocator.newVariable("groupId", BIGINT)); + GroupIdNode groupId = new GroupIdNode(idAllocator.getNextId(), subPlan.getRoot(), groupingSets, groupingSetMappings, aggregationArguments, groupIdVariable.get()); subPlan = new PlanBuilder(groupingTranslations, groupId, analysis.getParameters()); } else { Assignments.Builder assignments = Assignments.builder(); aggregationArguments.forEach(assignments::putIdentity); - groupingSetMappings.forEach((key, value) -> assignments.put(key, value.toSymbolReference())); + groupingSetMappings.forEach((key, value) -> assignments.put(key, new SymbolReference(value.getName()))); ProjectNode project = new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), assignments.build()); subPlan = new PlanBuilder(groupingTranslations, project, analysis.getParameters()); } - TranslationMap aggregationTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap); + TranslationMap aggregationTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToVariableMap); aggregationTranslations.copyMappingsFrom(groupingTranslations); // 2.d. Rewrite aggregates - ImmutableMap.Builder aggregationsBuilder = ImmutableMap.builder(); + ImmutableMap.Builder aggregationsBuilder = ImmutableMap.builder(); boolean needPostProjectionCoercion = false; for (FunctionCall aggregate : analysis.getAggregates(node)) { Expression rewritten = argumentTranslations.rewrite(aggregate); - Symbol newSymbol = symbolAllocator.newSymbol(rewritten, analysis.getType(aggregate)); + VariableReferenceExpression newVariable = symbolAllocator.newVariable(rewritten, analysis.getType(aggregate)); // TODO: this is a hack, because we apply coercions to the output of expressions, rather than the arguments to expressions. // Therefore we can end up with this implicit cast, and have to move it into a post-projection @@ -558,19 +557,19 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) rewritten = ((Cast) rewritten).getExpression(); needPostProjectionCoercion = true; } - aggregationTranslations.put(aggregate, newSymbol); + aggregationTranslations.put(aggregate, newVariable); FunctionCall rewrittenFunction = (FunctionCall) rewritten; - aggregationsBuilder.put(newSymbol, + aggregationsBuilder.put(newVariable, new Aggregation( analysis.getFunctionHandle(aggregate), rewrittenFunction.getArguments(), rewrittenFunction.getFilter(), - rewrittenFunction.getOrderBy().map(OrderBy::getSortItems).map(PlannerUtils::toOrderingScheme), + rewrittenFunction.getOrderBy().map(OrderBy::getSortItems).map(sortItems -> toOrderingScheme(sortItems, symbolAllocator.getTypes())), rewrittenFunction.isDistinct(), Optional.empty())); } - Map aggregations = aggregationsBuilder.build(); + Map aggregations = aggregationsBuilder.build(); ImmutableSet.Builder globalGroupingSets = ImmutableSet.builder(); for (int i = 0; i < groupingSets.size(); i++) { @@ -579,12 +578,12 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) } } - ImmutableList.Builder groupingKeys = ImmutableList.builder(); + ImmutableList.Builder groupingKeys = ImmutableList.builder(); groupingSets.stream() .flatMap(List::stream) .distinct() .forEach(groupingKeys::add); - groupIdSymbol.ifPresent(groupingKeys::add); + groupIdVariable.ifPresent(groupingKeys::add); AggregationNode aggregationNode = new AggregationNode( idAllocator.getNextId(), @@ -597,7 +596,7 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), - groupIdSymbol); + groupIdVariable); subPlan = new PlanBuilder(aggregationTranslations, aggregationNode, analysis.getParameters()); @@ -607,13 +606,13 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) if (needPostProjectionCoercion) { ImmutableList.Builder alreadyCoerced = ImmutableList.builder(); alreadyCoerced.addAll(groupByExpressions); - groupIdSymbol.map(Symbol::toSymbolReference).ifPresent(alreadyCoerced::add); + groupIdVariable.map(variable -> new SymbolReference(variable.getName())).ifPresent(alreadyCoerced::add); subPlan = explicitCoercionFields(subPlan, alreadyCoerced.build(), analysis.getAggregates(node)); } // 4. Project and re-write all grouping functions - return handleGroupingOperations(subPlan, node, groupIdSymbol, columnOnlyGroupingSets); + return handleGroupingOperations(subPlan, node, groupIdVariable, columnOnlyGroupingSets); } private List> enumerateGroupingSets(Analysis.GroupingSetAnalysis groupingSetAnalysis) @@ -663,7 +662,7 @@ private List> enumerateGroupingSets(Analysis.GroupingSetAnalysis gr return allSets; } - private PlanBuilder handleGroupingOperations(PlanBuilder subPlan, QuerySpecification node, Optional groupIdSymbol, List> groupingSets) + private PlanBuilder handleGroupingOperations(PlanBuilder subPlan, QuerySpecification node, Optional groupIdVariable, List> groupingSets) { if (analysis.getGroupingOperations(node).isEmpty()) { return subPlan; @@ -672,7 +671,7 @@ private PlanBuilder handleGroupingOperations(PlanBuilder subPlan, QuerySpecifica TranslationMap newTranslations = subPlan.copyTranslations(); Assignments.Builder projections = Assignments.builder(); - projections.putIdentities(subPlan.getRoot().getOutputSymbols()); + projections.putIdentities(subPlan.getRoot().getOutputVariables()); List> descriptor = groupingSets.stream() .map(set -> set.stream() @@ -681,9 +680,9 @@ private PlanBuilder handleGroupingOperations(PlanBuilder subPlan, QuerySpecifica .collect(toImmutableList()); for (GroupingOperation groupingOperation : analysis.getGroupingOperations(node)) { - Expression rewritten = GroupingOperationRewriter.rewriteGroupingOperation(groupingOperation, descriptor, analysis.getColumnReferenceFields(), groupIdSymbol); + Expression rewritten = GroupingOperationRewriter.rewriteGroupingOperation(groupingOperation, descriptor, analysis.getColumnReferenceFields(), groupIdVariable); Type coercion = analysis.getCoercion(groupingOperation); - Symbol symbol = symbolAllocator.newSymbol(rewritten, analysis.getTypeWithCoercions(groupingOperation)); + VariableReferenceExpression variable = symbolAllocator.newVariable(rewritten, analysis.getTypeWithCoercions(groupingOperation)); if (coercion != null) { rewritten = new Cast( rewritten, @@ -691,8 +690,8 @@ private PlanBuilder handleGroupingOperations(PlanBuilder subPlan, QuerySpecifica false, metadata.getTypeManager().isTypeOnlyCoercion(analysis.getType(groupingOperation), coercion)); } - projections.put(symbol, rewritten); - newTranslations.put(groupingOperation, symbol); + projections.put(variable, rewritten); + newTranslations.put(groupingOperation, variable); } return new PlanBuilder(newTranslations, new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), projections.build()), analysis.getParameters()); @@ -753,35 +752,35 @@ private PlanBuilder window(PlanBuilder subPlan, List windowFunctio subPlan = subPlan.appendProjections(inputs.build(), symbolAllocator, idAllocator); // Rewrite PARTITION BY in terms of pre-projected inputs - ImmutableList.Builder partitionBySymbols = ImmutableList.builder(); + ImmutableList.Builder partitionByVariables = ImmutableList.builder(); for (Expression expression : window.getPartitionBy()) { - partitionBySymbols.add(subPlan.translate(expression)); + partitionByVariables.add(subPlan.translateToVariable(expression)); } // Rewrite ORDER BY in terms of pre-projected inputs - LinkedHashMap orderings = new LinkedHashMap<>(); + LinkedHashMap orderings = new LinkedHashMap<>(); for (SortItem item : getSortItemsFromOrderBy(window.getOrderBy())) { - Symbol symbol = subPlan.translate(item.getSortKey()); + VariableReferenceExpression variable = subPlan.translateToVariable(item.getSortKey()); // don't override existing keys, i.e. when "ORDER BY a ASC, a DESC" is specified - orderings.putIfAbsent(symbol, toSortOrder(item)); + orderings.putIfAbsent(variable, toSortOrder(item)); } // Rewrite frame bounds in terms of pre-projected inputs - Optional frameStartSymbol = Optional.empty(); - Optional frameEndSymbol = Optional.empty(); + Optional frameStartVariable = Optional.empty(); + Optional frameEndVariable = Optional.empty(); if (frameStart != null) { - frameStartSymbol = Optional.of(subPlan.translate(frameStart)); + frameStartVariable = Optional.of(subPlan.translate(frameStart)); } if (frameEnd != null) { - frameEndSymbol = Optional.of(subPlan.translate(frameEnd)); + frameEndVariable = Optional.of(subPlan.translate(frameEnd)); } WindowNode.Frame frame = new WindowNode.Frame( toWindowType(frameType), toBoundType(frameStartType), - frameStartSymbol, + frameStartVariable, toBoundType(frameEndType), - frameEndSymbol, + frameEndVariable, Optional.ofNullable(frameStart).map(Expression::toString), Optional.ofNullable(frameEnd).map(Expression::toString)); @@ -799,15 +798,15 @@ private PlanBuilder window(PlanBuilder subPlan, List windowFunctio // If refers to existing symbol, don't create another PlanNode if (rewritten instanceof SymbolReference) { if (needCoercion) { - subPlan = explicitCoercionSymbols(subPlan, subPlan.getRoot().getOutputSymbols(), ImmutableList.of(windowFunction)); + subPlan = explicitCoercionSymbols(subPlan, subPlan.getRoot().getOutputVariables(), ImmutableList.of(windowFunction)); } continue; } Type returnType = analysis.getType(windowFunction); - Symbol newSymbol = symbolAllocator.newSymbol(rewritten, returnType); - outputTranslations.put(windowFunction, newSymbol); + VariableReferenceExpression newVariable = symbolAllocator.newVariable(rewritten, returnType); + outputTranslations.put(windowFunction, newVariable); // TODO: replace arguments with RowExpression once we introduce subquery expression for RowExpression (#12745). // Wrap all arguments in CallExpression to be RawExpression. @@ -822,12 +821,11 @@ private PlanBuilder window(PlanBuilder subPlan, List windowFunctio ((FunctionCall) rewritten).getArguments().stream().map(OriginalExpressionUtils::castToRowExpression).collect(toImmutableList())), frame); - List sourceSymbols = subPlan.getRoot().getOutputSymbols(); - ImmutableList.Builder orderBySymbols = ImmutableList.builder(); - orderBySymbols.addAll(orderings.keySet()); + ImmutableList.Builder orderByVariables = ImmutableList.builder(); + orderByVariables.addAll(orderings.keySet()); Optional orderingScheme = Optional.empty(); if (!orderings.isEmpty()) { - orderingScheme = Optional.of(new OrderingScheme(orderBySymbols.build(), orderings)); + orderingScheme = Optional.of(new OrderingScheme(orderByVariables.build(), orderings)); } // create window node @@ -836,16 +834,16 @@ private PlanBuilder window(PlanBuilder subPlan, List windowFunctio idAllocator.getNextId(), subPlan.getRoot(), new WindowNode.Specification( - partitionBySymbols.build(), + partitionByVariables.build(), orderingScheme), - ImmutableMap.of(newSymbol, function), + ImmutableMap.of(newVariable, function), Optional.empty(), ImmutableSet.of(), 0), analysis.getParameters()); if (needCoercion) { - subPlan = explicitCoercionSymbols(subPlan, sourceSymbols, ImmutableList.of(windowFunction)); + subPlan = explicitCoercionSymbols(subPlan, subPlan.getRoot().getOutputVariables(), ImmutableList.of(windowFunction)); } } @@ -868,7 +866,7 @@ private PlanBuilder distinct(PlanBuilder subPlan, QuerySpecification node) idAllocator.getNextId(), subPlan.getRoot(), ImmutableMap.of(), - singleGroupingSet(subPlan.getRoot().getOutputSymbols()), + singleGroupingSet(subPlan.getRoot().getOutputVariables()), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), @@ -897,20 +895,20 @@ private PlanBuilder sort(PlanBuilder subPlan, Optional orderBy, Optiona Iterator sortItems = orderBy.get().getSortItems().iterator(); // This logic is similar to PlannerUtils::toOrderingScheme - ImmutableList.Builder orderBySymbols = ImmutableList.builder(); - Map orderings = new HashMap<>(); + ImmutableList.Builder orderByVariables = ImmutableList.builder(); + Map orderings = new HashMap<>(); for (Expression fieldOrExpression : orderByExpressions) { - Symbol symbol = subPlan.translate(fieldOrExpression); + VariableReferenceExpression variable = subPlan.translateToVariable(fieldOrExpression); SortItem sortItem = sortItems.next(); - if (!orderings.containsKey(symbol)) { - orderBySymbols.add(symbol); - orderings.put(symbol, toSortOrder(sortItem)); + if (!orderings.containsKey(variable)) { + orderByVariables.add(variable); + orderings.put(variable, toSortOrder(sortItem)); } } PlanNode planNode; - OrderingScheme orderingScheme = new OrderingScheme(orderBySymbols.build(), orderings); + OrderingScheme orderingScheme = new OrderingScheme(orderByVariables.build(), orderings); if (limit.isPresent() && !limit.get().equalsIgnoreCase("all")) { planNode = new TopNNode(idAllocator.getNextId(), subPlan.getRoot(), Long.parseLong(limit.get()), orderingScheme, TopNNode.Step.SINGLE); } @@ -943,14 +941,14 @@ private PlanBuilder limit(PlanBuilder subPlan, Optional orderBy, Option return subPlan; } - private static List toSymbolReferences(List symbols) + private static List toSymbolReferences(List variables) { - return symbols.stream() - .map(Symbol::toSymbolReference) + return variables.stream() + .map(variable -> new SymbolReference(variable.getName())) .collect(toImmutableList()); } - private static Map symbolsForExpressions(PlanBuilder builder, Iterable expressions) + private static Map variablesForExpressions(PlanBuilder builder, Iterable expressions) { return stream(expressions) .distinct() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlan.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlan.java index da37da5fab72e..de23e2d9a9e92 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlan.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlan.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.RelationType; import com.facebook.presto.sql.analyzer.Scope; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -38,10 +39,10 @@ class RelationPlan { private final PlanNode root; - private final List fieldMappings; // for each field in the relation, the corresponding symbol from "root" + private final List fieldMappings; // for each field in the relation, the corresponding variable from "root" private final Scope scope; - public RelationPlan(PlanNode root, Scope scope, List fieldMappings) + public RelationPlan(PlanNode root, Scope scope, List fieldMappings) { requireNonNull(root, "root is null"); requireNonNull(fieldMappings, "outputSymbols is null"); @@ -59,6 +60,12 @@ public RelationPlan(PlanNode root, Scope scope, List fieldMappings) } public Symbol getSymbol(int fieldIndex) + { + checkArgument(fieldIndex >= 0 && fieldIndex < fieldMappings.size(), "No field->symbol mapping for field %s", fieldIndex); + return new Symbol(fieldMappings.get(fieldIndex).getName()); + } + + public VariableReferenceExpression getVariable(int fieldIndex) { checkArgument(fieldIndex >= 0 && fieldIndex < fieldMappings.size(), "No field->symbol mapping for field %s", fieldIndex); return fieldMappings.get(fieldIndex); @@ -69,7 +76,7 @@ public PlanNode getRoot() return root; } - public List getFieldMappings() + public List getFieldMappings() { return fieldMappings; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 57922dfbf3b24..e2203616a3e67 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.RowType; @@ -106,7 +107,7 @@ class RelationPlanner private final Analysis analysis; private final SymbolAllocator symbolAllocator; private final PlanNodeIdAllocator idAllocator; - private final Map, Symbol> lambdaDeclarationToSymbolMap; + private final Map, VariableReferenceExpression> lambdaDeclarationToVariableMap; private final Metadata metadata; private final Session session; private final SubqueryPlanner subqueryPlanner; @@ -115,24 +116,24 @@ class RelationPlanner Analysis analysis, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, - Map, Symbol> lambdaDeclarationToSymbolMap, + Map, VariableReferenceExpression> lambdaDeclarationToVariableMap, Metadata metadata, Session session) { requireNonNull(analysis, "analysis is null"); requireNonNull(symbolAllocator, "symbolAllocator is null"); requireNonNull(idAllocator, "idAllocator is null"); - requireNonNull(lambdaDeclarationToSymbolMap, "lambdaDeclarationToSymbolMap is null"); + requireNonNull(lambdaDeclarationToVariableMap, "lambdaDeclarationToVariableMap is null"); requireNonNull(metadata, "metadata is null"); requireNonNull(session, "session is null"); this.analysis = analysis; this.symbolAllocator = symbolAllocator; this.idAllocator = idAllocator; - this.lambdaDeclarationToSymbolMap = lambdaDeclarationToSymbolMap; + this.lambdaDeclarationToVariableMap = lambdaDeclarationToVariableMap; this.metadata = metadata; this.session = session; - this.subqueryPlanner = new SubqueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, session); + this.subqueryPlanner = new SubqueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session); } @Override @@ -153,18 +154,17 @@ protected RelationPlan visitTable(Table node, Void context) TableHandle handle = analysis.getTableHandle(node); - ImmutableList.Builder outputSymbolsBuilder = ImmutableList.builder(); - ImmutableMap.Builder columns = ImmutableMap.builder(); + ImmutableList.Builder outputVariablesBuilder = ImmutableList.builder(); + ImmutableMap.Builder columns = ImmutableMap.builder(); for (Field field : scope.getRelationType().getAllFields()) { - Symbol symbol = symbolAllocator.newSymbol(field.getName().get(), field.getType()); - - outputSymbolsBuilder.add(symbol); - columns.put(symbol, analysis.getColumn(field)); + VariableReferenceExpression variable = symbolAllocator.newVariable(field.getName().get(), field.getType()); + outputVariablesBuilder.add(variable); + columns.put(variable, analysis.getColumn(field)); } - List outputSymbols = outputSymbolsBuilder.build(); - PlanNode root = new TableScanNode(idAllocator.getNextId(), handle, outputSymbols, columns.build()); - return new RelationPlan(root, scope, outputSymbols); + List outputVariables = outputVariablesBuilder.build(); + PlanNode root = new TableScanNode(idAllocator.getNextId(), handle, outputVariables, columns.build()); + return new RelationPlan(root, scope, outputVariables); } @Override @@ -173,18 +173,18 @@ protected RelationPlan visitAliasedRelation(AliasedRelation node, Void context) RelationPlan subPlan = process(node.getRelation(), context); PlanNode root = subPlan.getRoot(); - List mappings = subPlan.getFieldMappings(); + List mappings = subPlan.getFieldMappings(); if (node.getColumnNames() != null) { - ImmutableList.Builder newMappings = ImmutableList.builder(); + ImmutableList.Builder newMappings = ImmutableList.builder(); Assignments.Builder assignments = Assignments.builder(); // project only the visible columns from the underlying relation for (int i = 0; i < subPlan.getDescriptor().getAllFieldCount(); i++) { Field field = subPlan.getDescriptor().getFieldByIndex(i); if (!field.isHidden()) { - Symbol aliasedColumn = symbolAllocator.newSymbol(field); - assignments.put(aliasedColumn, subPlan.getFieldMappings().get(i).toSymbolReference()); + VariableReferenceExpression aliasedColumn = symbolAllocator.newVariable(field); + assignments.put(aliasedColumn, (new Symbol(subPlan.getFieldMappings().get(i).getName())).toSymbolReference()); newMappings.add(aliasedColumn); } } @@ -241,7 +241,7 @@ protected RelationPlan visitJoin(Join node, Void context) PlanBuilder rightPlanBuilder = initializePlanBuilder(rightPlan); // NOTE: symbols must be in the same order as the outputDescriptor - List outputSymbols = ImmutableList.builder() + List outputs = ImmutableList.builder() .addAll(leftPlan.getFieldMappings()) .addAll(rightPlan.getFieldMappings()) .build(); @@ -311,10 +311,10 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende for (int i = 0; i < leftComparisonExpressions.size(); i++) { if (joinConditionComparisonOperators.get(i) == ComparisonExpression.Operator.EQUAL) { - Symbol leftSymbol = leftPlanBuilder.translate(leftComparisonExpressions.get(i)); - Symbol rightSymbol = rightPlanBuilder.translate(rightComparisonExpressions.get(i)); + VariableReferenceExpression leftVariable = leftPlanBuilder.translateToVariable(leftComparisonExpressions.get(i)); + VariableReferenceExpression righVariable = rightPlanBuilder.translateToVariable(rightComparisonExpressions.get(i)); - equiClauses.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol)); + equiClauses.add(new JoinNode.EquiJoinClause(leftVariable, righVariable)); } else { Expression leftExpression = leftPlanBuilder.rewrite(leftComparisonExpressions.get(i)); @@ -329,9 +329,9 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende leftPlanBuilder.getRoot(), rightPlanBuilder.getRoot(), equiClauses.build(), - ImmutableList.builder() - .addAll(leftPlanBuilder.getRoot().getOutputSymbols()) - .addAll(rightPlanBuilder.getRoot().getOutputSymbols()) + ImmutableList.builder() + .addAll(leftPlanBuilder.getRoot().getOutputVariables()) + .addAll(rightPlanBuilder.getRoot().getOutputVariables()) .build(), Optional.empty(), Optional.empty(), @@ -351,9 +351,9 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende leftPlanBuilder = subqueryPlanner.handleUncorrelatedSubqueries(leftPlanBuilder, complexJoinExpressions, node); } - RelationPlan intermediateRootRelationPlan = new RelationPlan(root, analysis.getScope(node), outputSymbols); - TranslationMap translationMap = new TranslationMap(intermediateRootRelationPlan, analysis, lambdaDeclarationToSymbolMap); - translationMap.setFieldMappings(outputSymbols); + RelationPlan intermediateRootRelationPlan = new RelationPlan(root, analysis.getScope(node), outputs); + TranslationMap translationMap = new TranslationMap(intermediateRootRelationPlan, analysis, lambdaDeclarationToVariableMap); + translationMap.setFieldMappings(outputs); translationMap.putExpressionMappingsFrom(leftPlanBuilder.getTranslations()); translationMap.putExpressionMappingsFrom(rightPlanBuilder.getTranslations()); @@ -365,9 +365,9 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende leftPlanBuilder.getRoot(), rightPlanBuilder.getRoot(), equiClauses.build(), - ImmutableList.builder() - .addAll(leftPlanBuilder.getRoot().getOutputSymbols()) - .addAll(rightPlanBuilder.getRoot().getOutputSymbols()) + ImmutableList.builder() + .addAll(leftPlanBuilder.getRoot().getOutputVariables()) + .addAll(rightPlanBuilder.getRoot().getOutputVariables()) .build(), Optional.of(castToRowExpression(rewrittenFilterCondition)), Optional.empty(), @@ -392,7 +392,7 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende } } - return new RelationPlan(root, analysis.getScope(node), outputSymbols); + return new RelationPlan(root, analysis.getScope(node), outputs); } private RelationPlan planJoinUsing(Join node, RelationPlan left, RelationPlan right) @@ -428,20 +428,20 @@ If casts are redundant (due to column type and common type being equal), ImmutableList.Builder clauses = ImmutableList.builder(); - Map leftJoinColumns = new HashMap<>(); - Map rightJoinColumns = new HashMap<>(); + Map leftJoinColumns = new HashMap<>(); + Map rightJoinColumns = new HashMap<>(); Assignments.Builder leftCoercions = Assignments.builder(); Assignments.Builder rightCoercions = Assignments.builder(); - leftCoercions.putIdentities(left.getRoot().getOutputSymbols()); - rightCoercions.putIdentities(right.getRoot().getOutputSymbols()); + leftCoercions.putIdentities(left.getRoot().getOutputVariables()); + rightCoercions.putIdentities(right.getRoot().getOutputVariables()); for (int i = 0; i < joinColumns.size(); i++) { Identifier identifier = joinColumns.get(i); Type type = analysis.getType(identifier); // compute the coercion for the field on the left to the common supertype of left & right - Symbol leftOutput = symbolAllocator.newSymbol(identifier, type); + VariableReferenceExpression leftOutput = symbolAllocator.newVariable(identifier, type); int leftField = joinAnalysis.getLeftJoinFields().get(i); leftCoercions.put(leftOutput, new Cast( left.getSymbol(leftField).toSymbolReference(), @@ -451,7 +451,7 @@ If casts are redundant (due to column type and common type being equal), leftJoinColumns.put(identifier, leftOutput); // compute the coercion for the field on the right to the common supertype of left & right - Symbol rightOutput = symbolAllocator.newSymbol(identifier, type); + VariableReferenceExpression rightOutput = symbolAllocator.newVariable(identifier, type); int rightField = joinAnalysis.getRightJoinFields().get(i); rightCoercions.put(rightOutput, new Cast( right.getSymbol(rightField).toSymbolReference(), @@ -472,9 +472,9 @@ If casts are redundant (due to column type and common type being equal), leftCoercion, rightCoercion, clauses.build(), - ImmutableList.builder() - .addAll(leftCoercion.getOutputSymbols()) - .addAll(rightCoercion.getOutputSymbols()) + ImmutableList.builder() + .addAll(leftCoercion.getOutputVariables()) + .addAll(rightCoercion.getOutputVariables()) .build(), Optional.empty(), Optional.empty(), @@ -485,25 +485,25 @@ If casts are redundant (due to column type and common type being equal), // which are defined as coalesce(l.k, r.k) Assignments.Builder assignments = Assignments.builder(); - ImmutableList.Builder outputs = ImmutableList.builder(); + ImmutableList.Builder outputs = ImmutableList.builder(); for (Identifier column : joinColumns) { - Symbol output = symbolAllocator.newSymbol(column, analysis.getType(column)); + VariableReferenceExpression output = symbolAllocator.newVariable(column, analysis.getType(column)); outputs.add(output); assignments.put(output, new CoalesceExpression( - leftJoinColumns.get(column).toSymbolReference(), - rightJoinColumns.get(column).toSymbolReference())); + new SymbolReference(leftJoinColumns.get(column).getName()), + new SymbolReference(rightJoinColumns.get(column).getName()))); } for (int field : joinAnalysis.getOtherLeftFields()) { - Symbol symbol = left.getFieldMappings().get(field); - outputs.add(symbol); - assignments.put(symbol, symbol.toSymbolReference()); + VariableReferenceExpression variable = left.getFieldMappings().get(field); + outputs.add(variable); + assignments.put(variable, new SymbolReference(variable.getName())); } for (int field : joinAnalysis.getOtherRightFields()) { - Symbol symbol = right.getFieldMappings().get(field); - outputs.add(symbol); - assignments.put(symbol, symbol.toSymbolReference()); + VariableReferenceExpression variable = right.getFieldMappings().get(field); + outputs.add(variable); + assignments.put(variable, new SymbolReference(variable.getName())); } return new RelationPlan( @@ -542,11 +542,11 @@ private RelationPlan planLateralJoin(Join join, RelationPlan leftPlan, Lateral l PlanBuilder planBuilder = subqueryPlanner.appendLateralJoin(leftPlanBuilder, rightPlanBuilder, lateral.getQuery(), true, LateralJoinNode.Type.INNER); - List outputSymbols = ImmutableList.builder() - .addAll(leftPlan.getRoot().getOutputSymbols()) - .addAll(rightPlan.getRoot().getOutputSymbols()) + List outputVariables = ImmutableList.builder() + .addAll(leftPlan.getRoot().getOutputVariables()) + .addAll(rightPlan.getRoot().getOutputVariables()) .build(); - return new RelationPlan(planBuilder.getRoot(), analysis.getScope(join), outputSymbols); + return new RelationPlan(planBuilder.getRoot(), analysis.getScope(join), outputVariables); } private static boolean isEqualComparisonExpression(Expression conjunct) @@ -558,12 +558,12 @@ private RelationPlan planCrossJoinUnnest(RelationPlan leftPlan, Join joinNode, U { RelationType unnestOutputDescriptor = analysis.getOutputDescriptor(node); // Create symbols for the result of unnesting - ImmutableList.Builder unnestedSymbolsBuilder = ImmutableList.builder(); + ImmutableList.Builder unnestedVariablesBuilder = ImmutableList.builder(); for (Field field : unnestOutputDescriptor.getVisibleFields()) { - Symbol symbol = symbolAllocator.newSymbol(field); - unnestedSymbolsBuilder.add(symbol); + VariableReferenceExpression variable = symbolAllocator.newVariable(field); + unnestedVariablesBuilder.add(variable); } - ImmutableList unnestedSymbols = unnestedSymbolsBuilder.build(); + ImmutableList unnestedVariables = unnestedVariablesBuilder.build(); // Add a projection for all the unnest arguments PlanBuilder planBuilder = initializePlanBuilder(leftPlan); @@ -571,36 +571,36 @@ private RelationPlan planCrossJoinUnnest(RelationPlan leftPlan, Join joinNode, U TranslationMap translations = planBuilder.getTranslations(); ProjectNode projectNode = (ProjectNode) planBuilder.getRoot(); - ImmutableMap.Builder> unnestSymbols = ImmutableMap.builder(); - UnmodifiableIterator unnestedSymbolsIterator = unnestedSymbols.iterator(); + ImmutableMap.Builder> unnestVariables = ImmutableMap.builder(); + UnmodifiableIterator unnestedVariablesIterator = unnestedVariables.iterator(); for (Expression expression : node.getExpressions()) { Type type = analysis.getType(expression); - Symbol inputSymbol = translations.get(expression); + VariableReferenceExpression inputVariable = new VariableReferenceExpression(translations.get(expression).getName(), type); if (type instanceof ArrayType) { Type elementType = ((ArrayType) type).getElementType(); if (!SystemSessionProperties.isLegacyUnnest(session) && elementType instanceof RowType) { - ImmutableList.Builder unnestSymbolBuilder = ImmutableList.builder(); + ImmutableList.Builder unnestVariableBuilder = ImmutableList.builder(); for (int i = 0; i < ((RowType) elementType).getFields().size(); i++) { - unnestSymbolBuilder.add(unnestedSymbolsIterator.next()); + unnestVariableBuilder.add(unnestedVariablesIterator.next()); } - unnestSymbols.put(inputSymbol, unnestSymbolBuilder.build()); + unnestVariables.put(inputVariable, unnestVariableBuilder.build()); } else { - unnestSymbols.put(inputSymbol, ImmutableList.of(unnestedSymbolsIterator.next())); + unnestVariables.put(inputVariable, ImmutableList.of(unnestedVariablesIterator.next())); } } else if (type instanceof MapType) { - unnestSymbols.put(inputSymbol, ImmutableList.of(unnestedSymbolsIterator.next(), unnestedSymbolsIterator.next())); + unnestVariables.put(inputVariable, ImmutableList.of(unnestedVariablesIterator.next(), unnestedVariablesIterator.next())); } else { throw new IllegalArgumentException("Unsupported type for UNNEST: " + type); } } - Optional ordinalitySymbol = node.isWithOrdinality() ? Optional.of(unnestedSymbolsIterator.next()) : Optional.empty(); - checkState(!unnestedSymbolsIterator.hasNext(), "Not all output symbols were matched with input symbols"); + Optional ordinalityVariable = node.isWithOrdinality() ? Optional.of(unnestedVariablesIterator.next()) : Optional.empty(); + checkState(!unnestedVariablesIterator.hasNext(), "Not all output symbols were matched with input symbols"); - UnnestNode unnestNode = new UnnestNode(idAllocator.getNextId(), projectNode, leftPlan.getFieldMappings(), unnestSymbols.build(), ordinalitySymbol); - return new RelationPlan(unnestNode, analysis.getScope(joinNode), unnestNode.getOutputSymbols()); + UnnestNode unnestNode = new UnnestNode(idAllocator.getNextId(), projectNode, leftPlan.getFieldMappings(), unnestVariables.build(), ordinalityVariable); + return new RelationPlan(unnestNode, analysis.getScope(joinNode), unnestNode.getOutputVariables()); } @Override @@ -612,14 +612,14 @@ protected RelationPlan visitTableSubquery(TableSubquery node, Void context) @Override protected RelationPlan visitQuery(Query node, Void context) { - return new QueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, session) + return new QueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session) .plan(node); } @Override protected RelationPlan visitQuerySpecification(QuerySpecification node, Void context) { - return new QueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, session) + return new QueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session) .plan(node); } @@ -628,9 +628,11 @@ protected RelationPlan visitValues(Values node, Void context) { Scope scope = analysis.getScope(node); ImmutableList.Builder outputSymbolsBuilder = ImmutableList.builder(); + ImmutableList.Builder outputVariablesBuilder = ImmutableList.builder(); for (Field field : scope.getRelationType().getVisibleFields()) { Symbol symbol = symbolAllocator.newSymbol(field); outputSymbolsBuilder.add(symbol); + outputVariablesBuilder.add(new VariableReferenceExpression(symbol.getName(), field.getType())); } ImmutableList.Builder> rowsBuilder = ImmutableList.builder(); @@ -650,59 +652,61 @@ protected RelationPlan visitValues(Values node, Void context) rowsBuilder.add(values.build()); } - ValuesNode valuesNode = new ValuesNode(idAllocator.getNextId(), outputSymbolsBuilder.build(), rowsBuilder.build()); - return new RelationPlan(valuesNode, scope, outputSymbolsBuilder.build()); + ValuesNode valuesNode = new ValuesNode(idAllocator.getNextId(), outputVariablesBuilder.build(), rowsBuilder.build()); + return new RelationPlan(valuesNode, scope, outputVariablesBuilder.build()); } @Override protected RelationPlan visitUnnest(Unnest node, Void context) { Scope scope = analysis.getScope(node); - ImmutableList.Builder outputSymbolsBuilder = ImmutableList.builder(); + ImmutableList.Builder outputVariablesBuilder = ImmutableList.builder(); for (Field field : scope.getRelationType().getVisibleFields()) { - Symbol symbol = symbolAllocator.newSymbol(field); - outputSymbolsBuilder.add(symbol); + VariableReferenceExpression variable = symbolAllocator.newVariable(field); + outputVariablesBuilder.add(variable); } - List unnestedSymbols = outputSymbolsBuilder.build(); + List unnestedVariables = outputVariablesBuilder.build(); // If we got here, then we must be unnesting a constant, and not be in a join (where there could be column references) - ImmutableList.Builder argumentSymbols = ImmutableList.builder(); + ImmutableList.Builder argumentVariables = ImmutableList.builder(); ImmutableList.Builder values = ImmutableList.builder(); - ImmutableMap.Builder> unnestSymbols = ImmutableMap.builder(); - Iterator unnestedSymbolsIterator = unnestedSymbols.iterator(); + ImmutableMap.Builder> unnestVariables = ImmutableMap.builder(); + Iterator unnestedVariablesIterator = unnestedVariables.iterator(); for (Expression expression : node.getExpressions()) { Type type = analysis.getType(expression); Expression rewritten = Coercer.addCoercions(expression, analysis); rewritten = ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(analysis.getParameters(), analysis), rewritten); values.add(castToRowExpression(rewritten)); - Symbol inputSymbol = symbolAllocator.newSymbol(rewritten, type); - argumentSymbols.add(inputSymbol); + VariableReferenceExpression input = symbolAllocator.newVariable(rewritten, type); + argumentVariables.add(new VariableReferenceExpression(input.getName(), type)); if (type instanceof ArrayType) { Type elementType = ((ArrayType) type).getElementType(); if (!SystemSessionProperties.isLegacyUnnest(session) && elementType instanceof RowType) { - ImmutableList.Builder unnestSymbolBuilder = ImmutableList.builder(); + ImmutableList.Builder unnestVariableBuilder = ImmutableList.builder(); for (int i = 0; i < ((RowType) elementType).getFields().size(); i++) { - unnestSymbolBuilder.add(unnestedSymbolsIterator.next()); + unnestVariableBuilder.add(unnestedVariablesIterator.next()); } - unnestSymbols.put(inputSymbol, unnestSymbolBuilder.build()); + unnestVariables.put(input, unnestVariableBuilder.build()); } else { - unnestSymbols.put(inputSymbol, ImmutableList.of(unnestedSymbolsIterator.next())); + unnestVariables.put(input, ImmutableList.of(unnestedVariablesIterator.next())); } } else if (type instanceof MapType) { - unnestSymbols.put(inputSymbol, ImmutableList.of(unnestedSymbolsIterator.next(), unnestedSymbolsIterator.next())); + unnestVariables.put(input, ImmutableList.of(unnestedVariablesIterator.next(), unnestedVariablesIterator.next())); } else { throw new IllegalArgumentException("Unsupported type for UNNEST: " + type); } } - Optional ordinalitySymbol = node.isWithOrdinality() ? Optional.of(unnestedSymbolsIterator.next()) : Optional.empty(); - checkState(!unnestedSymbolsIterator.hasNext(), "Not all output symbols were matched with input symbols"); - ValuesNode valuesNode = new ValuesNode(idAllocator.getNextId(), argumentSymbols.build(), ImmutableList.of(values.build())); + Optional ordinalityVariable = node.isWithOrdinality() ? Optional.of(unnestedVariablesIterator.next()) : Optional.empty(); + checkState(!unnestedVariablesIterator.hasNext(), "Not all output symbols were matched with input symbols"); + ValuesNode valuesNode = new ValuesNode( + idAllocator.getNextId(), + argumentVariables.build(), ImmutableList.of(values.build())); - UnnestNode unnestNode = new UnnestNode(idAllocator.getNextId(), valuesNode, ImmutableList.of(), unnestSymbols.build(), ordinalitySymbol); - return new RelationPlan(unnestNode, scope, unnestedSymbols); + UnnestNode unnestNode = new UnnestNode(idAllocator.getNextId(), valuesNode, ImmutableList.of(), unnestVariables.build(), ordinalityVariable); + return new RelationPlan(unnestNode, scope, unnestedVariables); } private RelationPlan processAndCoerceIfNecessary(Relation node, Void context) @@ -720,27 +724,27 @@ private RelationPlan processAndCoerceIfNecessary(Relation node, Void context) private RelationPlan addCoercions(RelationPlan plan, Type[] targetColumnTypes) { - List oldSymbols = plan.getFieldMappings(); + List oldVariables = plan.getFieldMappings(); RelationType oldDescriptor = plan.getDescriptor().withOnlyVisibleFields(); - verify(targetColumnTypes.length == oldSymbols.size()); - ImmutableList.Builder newSymbols = new ImmutableList.Builder<>(); + verify(targetColumnTypes.length == oldVariables.size()); + ImmutableList.Builder newVariables = new ImmutableList.Builder<>(); Field[] newFields = new Field[targetColumnTypes.length]; Assignments.Builder assignments = Assignments.builder(); for (int i = 0; i < targetColumnTypes.length; i++) { - Symbol inputSymbol = oldSymbols.get(i); - Type inputType = symbolAllocator.getTypes().get(inputSymbol); + VariableReferenceExpression inputVariable = oldVariables.get(i); + Symbol inputSymbol = new Symbol(inputVariable.getName()); Type outputType = targetColumnTypes[i]; - if (!outputType.equals(inputType)) { + if (!outputType.equals(inputVariable.getType())) { Expression cast = new Cast(inputSymbol.toSymbolReference(), outputType.getTypeSignature().toString()); - Symbol outputSymbol = symbolAllocator.newSymbol(cast, outputType); - assignments.put(outputSymbol, cast); - newSymbols.add(outputSymbol); + VariableReferenceExpression outputVariable = symbolAllocator.newVariable(cast, outputType); + assignments.put(outputVariable, cast); + newVariables.add(outputVariable); } else { SymbolReference symbolReference = inputSymbol.toSymbolReference(); - Symbol outputSymbol = symbolAllocator.newSymbol(symbolReference, outputType); - assignments.put(outputSymbol, symbolReference); - newSymbols.add(outputSymbol); + VariableReferenceExpression outputVariable = symbolAllocator.newVariable(symbolReference, outputType); + assignments.put(outputVariable, symbolReference); + newVariables.add(outputVariable); } Field oldField = oldDescriptor.getFieldByIndex(i); newFields[i] = new Field( @@ -753,7 +757,7 @@ private RelationPlan addCoercions(RelationPlan plan, Type[] targetColumnTypes) oldField.isAliased()); } ProjectNode projectNode = new ProjectNode(idAllocator.getNextId(), plan.getRoot(), assignments.build()); - return new RelationPlan(projectNode, Scope.builder().withRelationType(RelationId.anonymous(), new RelationType(newFields)).build(), newSymbols.build()); + return new RelationPlan(projectNode, Scope.builder().withRelationType(RelationId.anonymous(), new RelationType(newFields)).build(), newVariables.build()); } @Override @@ -763,11 +767,11 @@ protected RelationPlan visitUnion(Union node, Void context) SetOperationPlan setOperationPlan = process(node); - PlanNode planNode = new UnionNode(idAllocator.getNextId(), setOperationPlan.getSources(), setOperationPlan.getSymbolMapping(), ImmutableList.copyOf(setOperationPlan.getSymbolMapping().keySet())); + PlanNode planNode = new UnionNode(idAllocator.getNextId(), setOperationPlan.getSources(), setOperationPlan.getVariableMapping()); if (node.isDistinct()) { planNode = distinct(planNode); } - return new RelationPlan(planNode, analysis.getScope(node), planNode.getOutputSymbols()); + return new RelationPlan(planNode, analysis.getScope(node), planNode.getOutputVariables()); } @Override @@ -777,8 +781,8 @@ protected RelationPlan visitIntersect(Intersect node, Void context) SetOperationPlan setOperationPlan = process(node); - PlanNode planNode = new IntersectNode(idAllocator.getNextId(), setOperationPlan.getSources(), setOperationPlan.getSymbolMapping(), ImmutableList.copyOf(setOperationPlan.getSymbolMapping().keySet())); - return new RelationPlan(planNode, analysis.getScope(node), planNode.getOutputSymbols()); + PlanNode planNode = new IntersectNode(idAllocator.getNextId(), setOperationPlan.getSources(), setOperationPlan.getVariableMapping()); + return new RelationPlan(planNode, analysis.getScope(node), planNode.getOutputVariables()); } @Override @@ -788,56 +792,56 @@ protected RelationPlan visitExcept(Except node, Void context) SetOperationPlan setOperationPlan = process(node); - PlanNode planNode = new ExceptNode(idAllocator.getNextId(), setOperationPlan.getSources(), setOperationPlan.getSymbolMapping(), ImmutableList.copyOf(setOperationPlan.getSymbolMapping().keySet())); - return new RelationPlan(planNode, analysis.getScope(node), planNode.getOutputSymbols()); + PlanNode planNode = new ExceptNode(idAllocator.getNextId(), setOperationPlan.getSources(), setOperationPlan.getVariableMapping()); + return new RelationPlan(planNode, analysis.getScope(node), planNode.getOutputVariables()); } private SetOperationPlan process(SetOperation node) { - List outputs = null; + List outputs = null; ImmutableList.Builder sources = ImmutableList.builder(); - ImmutableListMultimap.Builder symbolMapping = ImmutableListMultimap.builder(); + ImmutableListMultimap.Builder variableMapping = ImmutableListMultimap.builder(); List subPlans = node.getRelations().stream() .map(relation -> processAndCoerceIfNecessary(relation, null)) .collect(toImmutableList()); for (RelationPlan relationPlan : subPlans) { - List childOutputSymbols = relationPlan.getFieldMappings(); + List childOutputVariables = relationPlan.getFieldMappings(); if (outputs == null) { // Use the first Relation to derive output symbol names RelationType descriptor = relationPlan.getDescriptor(); - ImmutableList.Builder outputSymbolBuilder = ImmutableList.builder(); + ImmutableList.Builder outputVariableBuilder = ImmutableList.builder(); for (Field field : descriptor.getVisibleFields()) { int fieldIndex = descriptor.indexOf(field); - Symbol symbol = childOutputSymbols.get(fieldIndex); - outputSymbolBuilder.add(symbolAllocator.newSymbol(symbol.getName(), symbolAllocator.getTypes().get(symbol))); + VariableReferenceExpression variable = childOutputVariables.get(fieldIndex); + outputVariableBuilder.add(symbolAllocator.newVariable(variable)); } - outputs = outputSymbolBuilder.build(); + outputs = outputVariableBuilder.build(); } RelationType descriptor = relationPlan.getDescriptor(); checkArgument(descriptor.getVisibleFieldCount() == outputs.size(), - "Expected relation to have %s symbols but has %s symbols", + "Expected relation to have %s variables but has %s variables", descriptor.getVisibleFieldCount(), outputs.size()); int fieldId = 0; for (Field field : descriptor.getVisibleFields()) { int fieldIndex = descriptor.indexOf(field); - symbolMapping.put(outputs.get(fieldId), childOutputSymbols.get(fieldIndex)); + variableMapping.put(outputs.get(fieldId), childOutputVariables.get(fieldIndex)); fieldId++; } sources.add(relationPlan.getRoot()); } - return new SetOperationPlan(sources.build(), symbolMapping.build()); + return new SetOperationPlan(sources.build(), variableMapping.build()); } private PlanBuilder initializePlanBuilder(RelationPlan relationPlan) { - TranslationMap translations = new TranslationMap(relationPlan, analysis, lambdaDeclarationToSymbolMap); + TranslationMap translations = new TranslationMap(relationPlan, analysis, lambdaDeclarationToVariableMap); // Make field->symbol mapping from underlying relation plan available for translations // This makes it possible to rewrite FieldOrExpressions that reference fields from the underlying tuple directly @@ -851,7 +855,7 @@ private PlanNode distinct(PlanNode node) return new AggregationNode(idAllocator.getNextId(), node, ImmutableMap.of(), - singleGroupingSet(node.getOutputSymbols()), + singleGroupingSet(node.getOutputVariables()), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), @@ -861,12 +865,12 @@ private PlanNode distinct(PlanNode node) private static class SetOperationPlan { private final List sources; - private final ListMultimap symbolMapping; + private final ListMultimap variableMapping; - private SetOperationPlan(List sources, ListMultimap symbolMapping) + private SetOperationPlan(List sources, ListMultimap variableMapping) { this.sources = sources; - this.symbolMapping = symbolMapping; + this.variableMapping = variableMapping; } public List getSources() @@ -874,9 +878,22 @@ public List getSources() return sources; } + public ListMultimap getVariableMapping() + { + return variableMapping; + } + public ListMultimap getSymbolMapping() { - return symbolMapping; + ImmutableListMultimap.Builder builder = ImmutableListMultimap.builder(); + variableMapping.asMap().entrySet().stream() + .forEach(entry -> builder.putAll( + new Symbol(entry.getKey().getName()), + entry.getValue().stream() + .map(VariableReferenceExpression::getName) + .map(Symbol::new) + .collect(toImmutableList()))); + return builder.build(); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionExtractor.java index 0b1a57f62d6d1..df98fb3ea76f2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionExtractor.java @@ -28,6 +28,7 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import java.util.List; import java.util.Optional; @@ -35,7 +36,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Collections.singletonList; import static java.util.Comparator.comparing; import static java.util.Objects.requireNonNull; @@ -65,10 +65,10 @@ by sorting position links according to the result of f(...) function. */ private SortExpressionExtractor() {} - public static Optional extractSortExpression(Set buildSymbols, RowExpression filter, FunctionManager functionManager) + public static Optional extractSortExpression(Set buildVariables, RowExpression filter, FunctionManager functionManager) { List filterConjuncts = LogicalRowExpressions.extractConjuncts(filter); - SortExpressionVisitor visitor = new SortExpressionVisitor(buildSymbols, functionManager); + SortExpressionVisitor visitor = new SortExpressionVisitor(buildVariables, functionManager); DeterminismEvaluator determinismEvaluator = new RowExpressionDeterminismEvaluator(functionManager); List sortExpressionCandidates = filterConjuncts.stream() @@ -100,12 +100,12 @@ private static SortExpressionContext merge(SortExpressionContext left, SortExpre private static class SortExpressionVisitor implements RowExpressionVisitor, Void> { - private final Set buildSymbols; + private final Set buildVariables; private final FunctionManager functionManager; - public SortExpressionVisitor(Set buildSymbols, FunctionManager functionManager) + public SortExpressionVisitor(Set buildVariables, FunctionManager functionManager) { - this.buildSymbols = buildSymbols; + this.buildVariables = buildVariables; this.functionManager = functionManager; } @@ -124,11 +124,11 @@ public Optional visitCall(CallExpression call, Void conte case LESS_THAN_OR_EQUAL: RowExpression left = call.getArguments().get(0); RowExpression right = call.getArguments().get(1); - Optional sortChannel = asBuildVariableReference(buildSymbols, right); - boolean hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, left); + Optional sortChannel = asBuildVariableReference(buildVariables, right); + boolean hasBuildReferencesOnOtherSide = hasBuildVariableReference(buildVariables, left); if (!sortChannel.isPresent()) { - sortChannel = asBuildVariableReference(buildSymbols, left); - hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, right); + sortChannel = asBuildVariableReference(buildVariables, left); + hasBuildReferencesOnOtherSide = hasBuildVariableReference(buildVariables, right); } if (sortChannel.isPresent() && !hasBuildReferencesOnOtherSide) { return sortChannel.map(variableReference -> new SortExpressionContext(variableReference, singletonList(call))); @@ -170,33 +170,31 @@ public Optional visitSpecialForm(SpecialFormExpression sp } } - private static Optional asBuildVariableReference(Set buildLayout, RowExpression expression) + private static Optional asBuildVariableReference(Set buildLayout, RowExpression expression) { // Currently only we support only symbol as sort expression on build side if (expression instanceof VariableReferenceExpression) { VariableReferenceExpression reference = (VariableReferenceExpression) expression; - if (buildLayout.contains(new Symbol(reference.getName()))) { + if (buildLayout.contains(reference)) { return Optional.of(reference); } } return Optional.empty(); } - private static boolean hasBuildSymbolReference(Set buildSymbols, RowExpression expression) + private static boolean hasBuildVariableReference(Set buildVariables, RowExpression expression) { - return expression.accept(new BuildSymbolReferenceFinder(buildSymbols), null); + return expression.accept(new BuildVariableReferenceFinder(buildVariables), null); } - private static class BuildSymbolReferenceFinder + private static class BuildVariableReferenceFinder implements RowExpressionVisitor { - private final Set buildSymbols; + private final Set buildVariables; - public BuildSymbolReferenceFinder(Set buildSymbols) + public BuildVariableReferenceFinder(Set buildVariables) { - this.buildSymbols = requireNonNull(buildSymbols, "buildSymbols is null").stream() - .map(Symbol::getName) - .collect(toImmutableSet()); + this.buildVariables = ImmutableSet.copyOf(requireNonNull(buildVariables, "buildVariables is null")); } @Override @@ -231,7 +229,7 @@ public Boolean visitLambda(LambdaDefinitionExpression lambda, Void context) @Override public Boolean visitVariableReference(VariableReferenceExpression reference, Void context) { - return buildSymbols.contains(reference.getName()); + return buildVariables.contains(reference); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java index 4f359aacf1a63..07f2144f5d2e4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java @@ -13,13 +13,13 @@ */ package com.facebook.presto.sql.planner; -import com.facebook.presto.Session; import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.operator.aggregation.MaxDataSizeForStats; import com.facebook.presto.operator.aggregation.SumDataSizeForStats; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.statistics.ColumnStatisticMetadata; import com.facebook.presto.spi.statistics.ColumnStatisticType; import com.facebook.presto.spi.statistics.TableStatisticType; @@ -48,31 +48,29 @@ public class StatisticsAggregationPlanner { - private final Session session; private final SymbolAllocator symbolAllocator; private final Metadata metadata; - public StatisticsAggregationPlanner(Session session, SymbolAllocator symbolAllocator, Metadata metadata) + public StatisticsAggregationPlanner(SymbolAllocator symbolAllocator, Metadata metadata) { - this.session = requireNonNull(session, "session is null"); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.metadata = requireNonNull(metadata, "metadata is null"); } - public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMetadata statisticsMetadata, Map columnToSymbolMap) + public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMetadata statisticsMetadata, Map columnToVariableMap) { - StatisticAggregationsDescriptor.Builder descriptor = StatisticAggregationsDescriptor.builder(); + StatisticAggregationsDescriptor.Builder descriptor = StatisticAggregationsDescriptor.builder(); List groupingColumns = statisticsMetadata.getGroupingColumns(); - List groupingSymbols = groupingColumns.stream() - .map(columnToSymbolMap::get) + List groupingVariables = groupingColumns.stream() + .map(columnToVariableMap::get) .collect(toImmutableList()); - for (int i = 0; i < groupingSymbols.size(); i++) { - descriptor.addGrouping(groupingColumns.get(i), groupingSymbols.get(i)); + for (int i = 0; i < groupingVariables.size(); i++) { + descriptor.addGrouping(groupingColumns.get(i), groupingVariables.get(i)); } - ImmutableMap.Builder aggregations = ImmutableMap.builder(); + ImmutableMap.Builder aggregations = ImmutableMap.builder(); FunctionManager functionManager = metadata.getFunctionManager(); for (TableStatisticType type : statisticsMetadata.getTableStatistics()) { if (type != ROW_COUNT) { @@ -86,45 +84,44 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMeta Optional.empty(), false, Optional.empty()); - Symbol symbol = symbolAllocator.newSymbol("rowCount", BIGINT); - aggregations.put(symbol, aggregation); - descriptor.addTableStatistic(ROW_COUNT, symbol); + VariableReferenceExpression variable = symbolAllocator.newVariable("rowCount", BIGINT); + aggregations.put(variable, aggregation); + descriptor.addTableStatistic(ROW_COUNT, variable); } for (ColumnStatisticMetadata columnStatisticMetadata : statisticsMetadata.getColumnStatistics()) { String columnName = columnStatisticMetadata.getColumnName(); ColumnStatisticType statisticType = columnStatisticMetadata.getStatisticType(); - Symbol inputSymbol = columnToSymbolMap.get(columnName); - verify(inputSymbol != null, "inputSymbol is null"); - Type inputType = symbolAllocator.getTypes().get(inputSymbol); - verify(inputType != null, "inputType is null for symbol: %s", inputSymbol); - ColumnStatisticsAggregation aggregation = createColumnAggregation(statisticType, inputSymbol, inputType); - Symbol symbol = symbolAllocator.newSymbol(statisticType + ":" + columnName, aggregation.getOutputType()); - aggregations.put(symbol, aggregation.getAggregation()); - descriptor.addColumnStatistic(columnStatisticMetadata, symbol); + VariableReferenceExpression inputVariable = columnToVariableMap.get(columnName); + verify(inputVariable != null, "inputVariable is null"); + ColumnStatisticsAggregation aggregation = createColumnAggregation(statisticType, inputVariable); + VariableReferenceExpression variable = symbolAllocator.newVariable(statisticType + ":" + columnName, aggregation.getOutputType()); + aggregations.put(variable, aggregation.getAggregation()); + descriptor.addColumnStatistic(columnStatisticMetadata, variable); } - StatisticAggregations aggregation = new StatisticAggregations(aggregations.build(), groupingSymbols); + StatisticAggregations aggregation = new StatisticAggregations(aggregations.build(), groupingVariables); return new TableStatisticAggregation(aggregation, descriptor.build()); } - private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticType statisticType, Symbol input, Type inputType) + private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticType statisticType, VariableReferenceExpression input) { + SymbolReference symbolReference = new SymbolReference(input.getName()); switch (statisticType) { case MIN_VALUE: - return createAggregation("min", input.toSymbolReference(), inputType, inputType); + return createAggregation("min", symbolReference, input.getType(), input.getType()); case MAX_VALUE: - return createAggregation("max", input.toSymbolReference(), inputType, inputType); + return createAggregation("max", symbolReference, input.getType(), input.getType()); case NUMBER_OF_DISTINCT_VALUES: - return createAggregation("approx_distinct", input.toSymbolReference(), inputType, BIGINT); + return createAggregation("approx_distinct", symbolReference, input.getType(), BIGINT); case NUMBER_OF_NON_NULL_VALUES: - return createAggregation("count", input.toSymbolReference(), inputType, BIGINT); + return createAggregation("count", symbolReference, input.getType(), BIGINT); case NUMBER_OF_TRUE_VALUES: - return createAggregation("count_if", input.toSymbolReference(), BOOLEAN, BIGINT); + return createAggregation("count_if", symbolReference, BOOLEAN, BIGINT); case TOTAL_SIZE_IN_BYTES: - return createAggregation(SumDataSizeForStats.NAME, input.toSymbolReference(), inputType, BIGINT); + return createAggregation(SumDataSizeForStats.NAME, symbolReference, input.getType(), BIGINT); case MAX_VALUE_SIZE_IN_BYTES: - return createAggregation(MaxDataSizeForStats.NAME, input.toSymbolReference(), inputType, BIGINT); + return createAggregation(MaxDataSizeForStats.NAME, symbolReference, input.getType(), BIGINT); default: throw new IllegalArgumentException("Unsupported statistic type: " + statisticType); } @@ -150,11 +147,11 @@ private ColumnStatisticsAggregation createAggregation(String functionName, Symbo public static class TableStatisticAggregation { private final StatisticAggregations aggregations; - private final StatisticAggregationsDescriptor descriptor; + private final StatisticAggregationsDescriptor descriptor; private TableStatisticAggregation( StatisticAggregations aggregations, - StatisticAggregationsDescriptor descriptor) + StatisticAggregationsDescriptor descriptor) { this.aggregations = requireNonNull(aggregations, "statisticAggregations is null"); this.descriptor = requireNonNull(descriptor, "descriptor is null"); @@ -165,7 +162,7 @@ public StatisticAggregations getAggregations() return aggregations; } - public StatisticAggregationsDescriptor getDescriptor() + public StatisticAggregationsDescriptor getDescriptor() { return descriptor; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java index 9ae6ee9398417..400e16ba3f301 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java @@ -17,6 +17,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ApplyNode; @@ -77,7 +78,7 @@ class SubqueryPlanner private final Analysis analysis; private final SymbolAllocator symbolAllocator; private final PlanNodeIdAllocator idAllocator; - private final Map, Symbol> lambdaDeclarationToSymbolMap; + private final Map, VariableReferenceExpression> lambdaDeclarationToVariableMap; private final Metadata metadata; private final Session session; @@ -85,21 +86,21 @@ class SubqueryPlanner Analysis analysis, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, - Map, Symbol> lambdaDeclarationToSymbolMap, + Map, VariableReferenceExpression> lambdaDeclarationToVariableMap, Metadata metadata, Session session) { requireNonNull(analysis, "analysis is null"); requireNonNull(symbolAllocator, "symbolAllocator is null"); requireNonNull(idAllocator, "idAllocator is null"); - requireNonNull(lambdaDeclarationToSymbolMap, "lambdaDeclarationToSymbolMap is null"); + requireNonNull(lambdaDeclarationToVariableMap, "lambdaDeclarationToVariableMap is null"); requireNonNull(metadata, "metadata is null"); requireNonNull(session, "session is null"); this.analysis = analysis; this.symbolAllocator = symbolAllocator; this.idAllocator = idAllocator; - this.lambdaDeclarationToSymbolMap = lambdaDeclarationToSymbolMap; + this.lambdaDeclarationToVariableMap = lambdaDeclarationToVariableMap; this.metadata = metadata; this.session = session; } @@ -191,15 +192,15 @@ private PlanBuilder appendInPredicateApplyNode(PlanBuilder subPlan, InPredicate PlanBuilder subqueryPlan = createPlanBuilder(uncoercedValueListSubquery); subqueryPlan = subqueryPlan.appendProjections(ImmutableList.of(valueListSubquery), symbolAllocator, idAllocator); - SymbolReference valueList = subqueryPlan.translate(valueListSubquery).toSymbolReference(); + SymbolReference valueList = new SymbolReference(subqueryPlan.translate(valueListSubquery).getName()); - Symbol rewrittenValue = subPlan.translate(inPredicate.getValue()); - InPredicate inPredicateSubqueryExpression = new InPredicate(rewrittenValue.toSymbolReference(), valueList); - Symbol inPredicateSubquerySymbol = symbolAllocator.newSymbol(inPredicateSubqueryExpression, BOOLEAN); + VariableReferenceExpression rewrittenValue = subPlan.translate(inPredicate.getValue()); + InPredicate inPredicateSubqueryExpression = new InPredicate(new SymbolReference(rewrittenValue.getName()), valueList); + VariableReferenceExpression inPredicateSubqueryVariable = symbolAllocator.newVariable(inPredicateSubqueryExpression, BOOLEAN); - subPlan.getTranslations().put(inPredicate, inPredicateSubquerySymbol); + subPlan.getTranslations().put(inPredicate, inPredicateSubqueryVariable); - return appendApplyNode(subPlan, inPredicate, subqueryPlan.getRoot(), Assignments.of(inPredicateSubquerySymbol, inPredicateSubqueryExpression), correlationAllowed); + return appendApplyNode(subPlan, inPredicate, subqueryPlan.getRoot(), Assignments.of(inPredicateSubqueryVariable, inPredicateSubqueryExpression), correlationAllowed); } private PlanBuilder appendScalarSubqueryApplyNodes(PlanBuilder builder, Set scalarSubqueries, boolean correlationAllowed) @@ -224,12 +225,12 @@ private PlanBuilder appendScalarSubqueryApplyNode(PlanBuilder subPlan, SubqueryE subqueryPlan = subqueryPlan.withNewRoot(new EnforceSingleRowNode(idAllocator.getNextId(), subqueryPlan.getRoot())); subqueryPlan = subqueryPlan.appendProjections(coercions, symbolAllocator, idAllocator); - Symbol uncoercedScalarSubquerySymbol = subqueryPlan.translate(uncoercedScalarSubquery); - subPlan.getTranslations().put(uncoercedScalarSubquery, uncoercedScalarSubquerySymbol); + VariableReferenceExpression uncoercedScalarSubqueryVariable = subqueryPlan.translate(uncoercedScalarSubquery); + subPlan.getTranslations().put(uncoercedScalarSubquery, uncoercedScalarSubqueryVariable); for (Expression coercion : coercions) { - Symbol coercionSymbol = subqueryPlan.translate(coercion); - subPlan.getTranslations().put(coercion, coercionSymbol); + VariableReferenceExpression coercionVariable = subqueryPlan.translate(coercion); + subPlan.getTranslations().put(coercion, coercionVariable); } return appendLateralJoin(subPlan, subqueryPlan, scalarSubquery.getQuery(), correlationAllowed, LateralJoinNode.Type.LEFT); @@ -250,7 +251,7 @@ public PlanBuilder appendLateralJoin(PlanBuilder subPlan, PlanBuilder subqueryPl idAllocator.getNextId(), subPlan.getRoot(), subqueryNode, - ImmutableList.copyOf(SymbolsExtractor.extractUnique(correlation.values())), + ImmutableList.copyOf(SymbolsExtractor.extractUniqueVariable(correlation.values(), symbolAllocator.getTypes())), type, subQueryNotSupportedError(query, "Given correlated subquery")), analysis.getParameters()); @@ -291,7 +292,7 @@ private PlanBuilder appendExistSubqueryApplyNode(PlanBuilder subPlan, ExistsPred // add an explicit projection that removes all columns PlanNode subqueryNode = new ProjectNode(idAllocator.getNextId(), subqueryPlan.getRoot(), Assignments.of()); - Symbol exists = symbolAllocator.newSymbol("exists", BOOLEAN); + VariableReferenceExpression exists = symbolAllocator.newVariable("exists", BOOLEAN); subPlan.getTranslations().put(existsPredicate, exists); ExistsPredicate rewrittenExistsPredicate = new ExistsPredicate(BooleanLiteral.TRUE_LITERAL); return appendApplyNode( @@ -386,17 +387,17 @@ private PlanBuilder planQuantifiedApplyNode(PlanBuilder subPlan, QuantifiedCompa QuantifiedComparisonExpression coercedQuantifiedComparison = new QuantifiedComparisonExpression( quantifiedComparison.getOperator(), quantifiedComparison.getQuantifier(), - subPlan.translate(quantifiedComparison.getValue()).toSymbolReference(), - subqueryPlan.translate(quantifiedSubquery).toSymbolReference()); + new SymbolReference(subPlan.translate(quantifiedComparison.getValue()).getName()), + new SymbolReference(subqueryPlan.translate(quantifiedSubquery).getName())); - Symbol coercedQuantifiedComparisonSymbol = symbolAllocator.newSymbol(coercedQuantifiedComparison, BOOLEAN); - subPlan.getTranslations().put(quantifiedComparison, coercedQuantifiedComparisonSymbol); + VariableReferenceExpression coercedQuantifiedComparisonVariable = symbolAllocator.newVariable(coercedQuantifiedComparison, BOOLEAN); + subPlan.getTranslations().put(quantifiedComparison, coercedQuantifiedComparisonVariable); return appendApplyNode( subPlan, quantifiedComparison.getSubquery(), subqueryPlan.getRoot(), - Assignments.of(coercedQuantifiedComparisonSymbol, coercedQuantifiedComparison), + Assignments.of(coercedQuantifiedComparisonVariable, coercedQuantifiedComparison), correlationAllowed); } @@ -445,7 +446,7 @@ private PlanBuilder appendApplyNode(PlanBuilder subPlan, Node subquery, PlanNode root, subqueryNode, subqueryAssignments, - ImmutableList.copyOf(SymbolsExtractor.extractUnique(correlation.values())), + ImmutableList.copyOf(SymbolsExtractor.extractUniqueVariable(correlation.values(), symbolAllocator.getTypes())), subQueryNotSupportedError(subquery, "Given correlated subquery")), analysis.getParameters()); } @@ -477,9 +478,9 @@ private static Optional tryResolveMissingExpression(PlanBuilder subP private PlanBuilder createPlanBuilder(Node node) { - RelationPlan relationPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, session) + RelationPlan relationPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session) .process(node, null); - TranslationMap translations = new TranslationMap(relationPlan, analysis, lambdaDeclarationToSymbolMap); + TranslationMap translations = new TranslationMap(relationPlan, analysis, lambdaDeclarationToVariableMap); // Make field->symbol mapping from underlying relation plan available for translations // This makes it possible to rewrite FieldOrExpressions that reference fields from the FROM clause directly @@ -595,7 +596,7 @@ public PlanNode visitValues(ValuesNode node, RewriteContext context) .collect(toImmutableList()); return new ValuesNode( idAllocator.getNextId(), - rewrittenNode.getOutputSymbols(), + rewrittenNode.getOutputVariables(), rewrittenRows); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java index 34f07544b1c21..0f3b7ef0a1fbf 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.Field; @@ -24,10 +25,13 @@ import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.primitives.Ints; +import java.util.Collection; import java.util.HashMap; +import java.util.List; import java.util.Map; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; @@ -52,19 +56,46 @@ public Symbol newSymbol(Symbol symbolHint) return newSymbol(symbolHint.getName(), symbols.get(symbolHint)); } + public VariableReferenceExpression newVariable(Symbol symbolHint) + { + checkArgument(symbols.containsKey(symbolHint), "symbolHint not in symbols map"); + return newVariable(symbolHint.getName(), symbols.get(symbolHint)); + } + + public VariableReferenceExpression newVariable(VariableReferenceExpression variableHint) + { + return newVariable(variableHint.getName(), variableHint.getType()); + } + public Symbol newSymbol(QualifiedName nameHint, Type type) { return newSymbol(nameHint.getSuffix(), type, null); } + public VariableReferenceExpression newVariable(QualifiedName nameHint, Type type) + { + return newVariable(nameHint.getSuffix(), type, null); + } + public Symbol newSymbol(String nameHint, Type type) { return newSymbol(nameHint, type, null); } - public Symbol newHashSymbol() + public VariableReferenceExpression newVariable(String nameHint, Type type) + { + return newVariable(nameHint, type, null); + } + + public VariableReferenceExpression newHashVariable() + { + return newVariable("$hashValue", BigintType.BIGINT); + } + + public VariableReferenceExpression newVariable(String nameHint, Type type, String suffix) { - return newSymbol("$hashValue", BigintType.BIGINT); + Symbol symbol = newSymbol(nameHint, type, suffix); + return new VariableReferenceExpression(symbol.getName(), type); } public Symbol newSymbol(String nameHint, Type type, String suffix) @@ -102,6 +133,12 @@ public Symbol newSymbol(String nameHint, Type type, String suffix) return symbol; } + public VariableReferenceExpression newVariable(Expression expression, Type type) + { + Symbol symbol = newSymbol(expression, type); + return new VariableReferenceExpression(symbol.getName(), type); + } + public Symbol newSymbol(Expression expression, Type type) { return newSymbol(expression, type, null); @@ -126,12 +163,35 @@ else if (expression instanceof GroupingOperation) { return newSymbol(nameHint, type, suffix); } + public VariableReferenceExpression newVariable(Expression expression, Type type, String suffix) + { + String nameHint = "expr"; + if (expression instanceof Identifier) { + nameHint = ((Identifier) expression).getValue(); + } + else if (expression instanceof FunctionCall) { + nameHint = ((FunctionCall) expression).getName().getSuffix(); + } + else if (expression instanceof SymbolReference) { + nameHint = ((SymbolReference) expression).getName(); + } + else if (expression instanceof GroupingOperation) { + nameHint = "grouping"; + } + return newVariable(nameHint, type, suffix); + } + public Symbol newSymbol(Field field) { String nameHint = field.getName().orElse("field"); return newSymbol(nameHint, field.getType()); } + public VariableReferenceExpression newVariable(Field field) + { + return newVariable(field.getName().orElse("field"), field.getType(), null); + } + public TypeProvider getTypes() { return TypeProvider.viewOf(symbols); @@ -141,4 +201,16 @@ private int nextId() { return nextId++; } + + public VariableReferenceExpression toVariableReference(Symbol symbol) + { + return new VariableReferenceExpression(symbol.getName(), getTypes().get(symbol)); + } + + public List toVariableReferences(Collection symbols) + { + return symbols.stream() + .map(symbol -> new VariableReferenceExpression(symbol.getName(), getTypes().get(symbol))) + .collect(toImmutableList()); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java index 91e531b69bb64..d6473fd89b6c9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java @@ -54,6 +54,14 @@ public static Set extractUnique(PlanNode node) return uniqueSymbols.build(); } + public static Set extractUniqueVariable(PlanNode node, TypeProvider types) + { + ImmutableSet.Builder unique = ImmutableSet.builder(); + extractExpressions(node).forEach(expression -> unique.addAll(extractUniqueVariableInternal(expression, types))); + + return unique.build(); + } + public static Set extractUniqueNonRecursive(PlanNode node) { ImmutableSet.Builder uniqueSymbols = ImmutableSet.builder(); @@ -70,17 +78,34 @@ public static Set extractUnique(PlanNode node, Lookup lookup) return uniqueSymbols.build(); } + public static Set extractUniqueVariable(PlanNode node, Lookup lookup, TypeProvider types) + { + ImmutableSet.Builder unique = ImmutableSet.builder(); + extractExpressions(node, lookup).forEach(expression -> unique.addAll(extractUniqueVariableInternal(expression, types))); + return unique.build(); + } + public static Set extractUnique(Expression expression) { return ImmutableSet.copyOf(extractAll(expression)); } + public static Set extractUniqueVariable(Expression expression, TypeProvider types) + { + return ImmutableSet.copyOf(extractAllVariable(expression, types)); + } + // TODO: return Set public static Set extractUnique(RowExpression expression) { return ImmutableSet.copyOf(extractAll(expression).stream().map(variable -> new Symbol(variable.getName())).collect(toSet())); } + public static Set extractUniqueVariable(RowExpression expression) + { + return ImmutableSet.copyOf(extractAll(expression)); + } + public static Set extractUnique(Iterable expressions) { ImmutableSet.Builder unique = ImmutableSet.builder(); @@ -90,12 +115,27 @@ public static Set extractUnique(Iterable expressio return unique.build(); } + public static Set extractUniqueVariable(Iterable expressions, TypeProvider types) + { + ImmutableSet.Builder unique = ImmutableSet.builder(); + for (Expression expression : expressions) { + unique.addAll(extractAllVariable(expression, types)); + } + return unique.build(); + } + public static List extractAll(Expression expression) { ImmutableList.Builder builder = ImmutableList.builder(); new SymbolBuilderVisitor().process(expression, builder); return builder.build(); } + public static List extractAllVariable(Expression expression, TypeProvider types) + { + ImmutableList.Builder builder = ImmutableList.builder(); + new VariableFromExpressionBuilderVisitor(types).process(expression, builder); + return builder.build(); + } public static List extractAll(RowExpression expression) { @@ -112,20 +152,28 @@ public static Set extractNames(Expression expression, Set extractOutputSymbols(PlanNode planNode) + public static Set extractOutputVariables(PlanNode planNode) { - return extractOutputSymbols(planNode, noLookup()); + return extractOutputVariables(planNode, noLookup()); } - public static Set extractOutputSymbols(PlanNode planNode, Lookup lookup) + public static Set extractOutputVariables(PlanNode planNode, Lookup lookup) { return searchFrom(planNode, lookup) .findAll() .stream() - .flatMap(node -> node.getOutputSymbols().stream()) + .flatMap(node -> node.getOutputVariables().stream()) .collect(toImmutableSet()); } + public static Set extractOutputVariables(PlanNode planNode, Lookup lookup, TypeProvider types) + { + return searchFrom(planNode, lookup) + .findAll() + .stream() + .flatMap(node -> node.getOutputVariables().stream()) + .collect(toImmutableSet()); + } /** * {@param expression} could be an OriginalExpression */ @@ -137,6 +185,14 @@ private static Set extractUniqueInternal(RowExpression expression) return extractUnique(expression); } + private static Set extractUniqueVariableInternal(RowExpression expression, TypeProvider types) + { + if (isExpression(expression)) { + return extractUniqueVariable(castToExpression(expression), types); + } + return extractUniqueVariable(expression); + } + private static class SymbolBuilderVisitor extends DefaultExpressionTraversalVisitor> { @@ -148,6 +204,24 @@ protected Void visitSymbolReference(SymbolReference node, ImmutableList.Builder< } } + private static class VariableFromExpressionBuilderVisitor + extends DefaultExpressionTraversalVisitor> + { + private final TypeProvider types; + + protected VariableFromExpressionBuilderVisitor(TypeProvider types) + { + this.types = types; + } + + @Override + protected Void visitSymbolReference(SymbolReference node, ImmutableList.Builder builder) + { + builder.add(new VariableReferenceExpression(node.getName(), types.get(Symbol.from(node)))); + return null; + } + } + private static class VariableBuilderVisitor extends DefaultRowExpressionTraversalVisitor> { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java index a7fb3e44058f9..811c61b6b5513 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.analyzer.ResolvedField; @@ -27,6 +28,7 @@ import com.facebook.presto.sql.tree.LambdaExpression; import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.Parameter; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import java.util.HashMap; @@ -46,22 +48,22 @@ class TranslationMap // all expressions are rewritten in terms of fields declared by this relation plan private final RelationPlan rewriteBase; private final Analysis analysis; - private final Map, Symbol> lambdaDeclarationToSymbolMap; + private final Map, VariableReferenceExpression> lambdaDeclarationToVariableMap; // current mappings of underlying field -> symbol for translating direct field references - private final Symbol[] fieldSymbols; + private final VariableReferenceExpression[] fieldVariables; // current mappings of sub-expressions -> symbol - private final Map expressionToSymbols = new HashMap<>(); + private final Map expressionToVariables = new HashMap<>(); private final Map expressionToExpressions = new HashMap<>(); - public TranslationMap(RelationPlan rewriteBase, Analysis analysis, Map, Symbol> lambdaDeclarationToSymbolMap) + public TranslationMap(RelationPlan rewriteBase, Analysis analysis, Map, VariableReferenceExpression> lambdaDeclarationToVariableMap) { this.rewriteBase = requireNonNull(rewriteBase, "rewriteBase is null"); this.analysis = requireNonNull(analysis, "analysis is null"); - this.lambdaDeclarationToSymbolMap = requireNonNull(lambdaDeclarationToSymbolMap, "lambdaDeclarationToSymbolMap is null"); + this.lambdaDeclarationToVariableMap = requireNonNull(lambdaDeclarationToVariableMap, "lambdaDeclarationToVariableMap is null"); - fieldSymbols = new Symbol[rewriteBase.getFieldMappings().size()]; + fieldVariables = new VariableReferenceExpression[rewriteBase.getFieldMappings().size()]; } public RelationPlan getRelationPlan() @@ -74,35 +76,35 @@ public Analysis getAnalysis() return analysis; } - public Map, Symbol> getLambdaDeclarationToSymbolMap() + public Map, VariableReferenceExpression> getLambdaDeclarationToVariableMap() { - return lambdaDeclarationToSymbolMap; + return lambdaDeclarationToVariableMap; } - public void setFieldMappings(List symbols) + public void setFieldMappings(List variables) { - checkArgument(symbols.size() == fieldSymbols.length, "size of symbols list (%s) doesn't match number of expected fields (%s)", symbols.size(), fieldSymbols.length); + checkArgument(variables.size() == fieldVariables.length, "size of variables list (%s) doesn't match number of expected fields (%s)", variables.size(), fieldVariables.length); - for (int i = 0; i < symbols.size(); i++) { - this.fieldSymbols[i] = symbols.get(i); + for (int i = 0; i < variables.size(); i++) { + this.fieldVariables[i] = variables.get(i); } } public void copyMappingsFrom(TranslationMap other) { - checkArgument(other.fieldSymbols.length == fieldSymbols.length, + checkArgument(other.fieldVariables.length == fieldVariables.length, "number of fields in other (%s) doesn't match number of expected fields (%s)", - other.fieldSymbols.length, - fieldSymbols.length); + other.fieldVariables.length, + fieldVariables.length); - expressionToSymbols.putAll(other.expressionToSymbols); + expressionToVariables.putAll(other.expressionToVariables); expressionToExpressions.putAll(other.expressionToExpressions); - System.arraycopy(other.fieldSymbols, 0, fieldSymbols, 0, other.fieldSymbols.length); + System.arraycopy(other.fieldVariables, 0, fieldVariables, 0, other.fieldVariables.length); } public void putExpressionMappingsFrom(TranslationMap other) { - expressionToSymbols.putAll(other.expressionToSymbols); + expressionToVariables.putAll(other.expressionToVariables); expressionToExpressions.putAll(other.expressionToExpressions); } @@ -117,8 +119,8 @@ public Expression rewrite(Expression expression) @Override public Expression rewriteExpression(Expression node, Void context, ExpressionTreeRewriter treeRewriter) { - if (expressionToSymbols.containsKey(node)) { - return expressionToSymbols.get(node).toSymbolReference(); + if (expressionToVariables.containsKey(node)) { + return new SymbolReference(expressionToVariables.get(node).getName()); } Expression translated = expressionToExpressions.getOrDefault(node, node); @@ -127,50 +129,50 @@ public Expression rewriteExpression(Expression node, Void context, ExpressionTre }, mapped); } - public void put(Expression expression, Symbol symbol) + public void put(Expression expression, VariableReferenceExpression variable) { if (expression instanceof FieldReference) { int fieldIndex = ((FieldReference) expression).getFieldIndex(); - fieldSymbols[fieldIndex] = symbol; - expressionToSymbols.put(rewriteBase.getSymbol(fieldIndex).toSymbolReference(), symbol); + fieldVariables[fieldIndex] = variable; + expressionToVariables.put(new SymbolReference(rewriteBase.getVariable(fieldIndex).getName()), variable); return; } Expression translated = translateNamesToSymbols(expression); - expressionToSymbols.put(translated, symbol); + expressionToVariables.put(translated, variable); // also update the field mappings if this expression is a field reference rewriteBase.getScope().tryResolveField(expression) .filter(ResolvedField::isLocal) - .ifPresent(field -> fieldSymbols[field.getHierarchyFieldIndex()] = symbol); + .ifPresent(field -> fieldVariables[field.getHierarchyFieldIndex()] = variable); } public boolean containsSymbol(Expression expression) { if (expression instanceof FieldReference) { int field = ((FieldReference) expression).getFieldIndex(); - return fieldSymbols[field] != null; + return fieldVariables[field] != null; } Expression translated = translateNamesToSymbols(expression); - return expressionToSymbols.containsKey(translated); + return expressionToVariables.containsKey(translated); } - public Symbol get(Expression expression) + public VariableReferenceExpression get(Expression expression) { if (expression instanceof FieldReference) { int field = ((FieldReference) expression).getFieldIndex(); - checkArgument(fieldSymbols[field] != null, "No mapping for field: %s", field); - return fieldSymbols[field]; + checkArgument(fieldVariables[field] != null, "No mapping for field: %s", field); + return fieldVariables[field]; } Expression translated = translateNamesToSymbols(expression); - if (!expressionToSymbols.containsKey(translated)) { + if (!expressionToVariables.containsKey(translated)) { checkArgument(expressionToExpressions.containsKey(translated), "No mapping for expression: %s", expression); return get(expressionToExpressions.get(translated)); } - return expressionToSymbols.get(translated); + return expressionToVariables.get(translated); } public void put(Expression expression, Expression rewritten) @@ -202,8 +204,8 @@ public Expression rewriteIdentifier(Identifier node, Void context, ExpressionTre { LambdaArgumentDeclaration referencedLambdaArgumentDeclaration = analysis.getLambdaArgumentReference(node); if (referencedLambdaArgumentDeclaration != null) { - Symbol symbol = lambdaDeclarationToSymbolMap.get(NodeRef.of(referencedLambdaArgumentDeclaration)); - return coerceIfNecessary(node, symbol.toSymbolReference()); + VariableReferenceExpression variable = lambdaDeclarationToVariableMap.get(NodeRef.of(referencedLambdaArgumentDeclaration)); + return coerceIfNecessary(node, new SymbolReference(variable.getName())); } else { return rewriteExpressionWithResolvedName(node); @@ -212,8 +214,8 @@ public Expression rewriteIdentifier(Identifier node, Void context, ExpressionTre private Expression rewriteExpressionWithResolvedName(Expression node) { - return getSymbol(rewriteBase, node) - .map(symbol -> coerceIfNecessary(node, symbol.toSymbolReference())) + return getVariable(rewriteBase, node) + .map(variable -> coerceIfNecessary(node, new SymbolReference(variable.getName()))) .orElse(coerceIfNecessary(node, node)); } @@ -224,8 +226,8 @@ public Expression rewriteDereferenceExpression(DereferenceExpression node, Void Optional resolvedField = rewriteBase.getScope().tryResolveField(node); if (resolvedField.isPresent()) { if (resolvedField.get().isLocal()) { - return getSymbol(rewriteBase, node) - .map(symbol -> coerceIfNecessary(node, symbol.toSymbolReference())) + return getVariable(rewriteBase, node) + .map(variable -> coerceIfNecessary(node, new SymbolReference(variable.getName()))) .orElseThrow(() -> new IllegalStateException("No symbol mapping for node " + node)); } } @@ -242,8 +244,8 @@ public Expression rewriteLambdaExpression(LambdaExpression node, Void context, E ImmutableList.Builder newArguments = ImmutableList.builder(); for (LambdaArgumentDeclaration argument : node.getArguments()) { - Symbol symbol = lambdaDeclarationToSymbolMap.get(NodeRef.of(argument)); - newArguments.add(new LambdaArgumentDeclaration(new Identifier(symbol.getName()))); + VariableReferenceExpression variable = lambdaDeclarationToVariableMap.get(NodeRef.of(argument)); + newArguments.add(new LambdaArgumentDeclaration(new Identifier(variable.getName()))); } Expression rewrittenBody = treeRewriter.rewrite(node.getBody(), null); return new LambdaExpression(newArguments.build(), rewrittenBody); @@ -271,7 +273,7 @@ private Expression coerceIfNecessary(Expression original, Expression rewritten) }, expression, null); } - private Optional getSymbol(RelationPlan plan, Expression expression) + private Optional getVariable(RelationPlan plan, Expression expression) { if (!analysis.isColumnReference(expression)) { // Expression can be a reference to lambda argument (or DereferenceExpression based on lambda argument reference). diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/TypeProvider.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/TypeProvider.java index 1bd17351e80c6..bc2f3a27dae55 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/TypeProvider.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/TypeProvider.java @@ -13,13 +13,19 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.google.common.collect.ImmutableMap; +import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.Map; +import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Objects.requireNonNull; public class TypeProvider @@ -41,6 +47,16 @@ public static TypeProvider empty() return new TypeProvider(ImmutableMap.of()); } + public static TypeProvider fromVariables(VariableReferenceExpression... variables) + { + return fromVariables(Arrays.asList(variables)); + } + + public static TypeProvider fromVariables(Collection variables) + { + return new TypeProvider(variables.stream().collect(toImmutableMap(variable -> new Symbol(variable.getName()), VariableReferenceExpression::getType))); + } + private TypeProvider(Map types) { this.types = types; @@ -61,4 +77,11 @@ public Map allTypes() // types may be a HashMap, so creating an ImmutableMap here would add extra cost when allTypes gets called frequently return Collections.unmodifiableMap(types); } + + public Set allVariables() + { + return types.entrySet().stream() + .map(entry -> new VariableReferenceExpression(entry.getKey().getName(), entry.getValue())) + .collect(toImmutableSet()); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/GroupReference.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/GroupReference.java index 6028c72aba666..5c7c11afc3f15 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/GroupReference.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/GroupReference.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.iterative; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.InternalPlanNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -26,9 +26,9 @@ public class GroupReference extends InternalPlanNode { private final int groupId; - private final List outputs; + private final List outputs; - public GroupReference(PlanNodeId id, int groupId, List outputs) + public GroupReference(PlanNodeId id, int groupId, List outputs) { super(id); this.groupId = groupId; @@ -53,7 +53,7 @@ public R accept(InternalPlanVisitor visitor, C context) } @Override - public List getOutputSymbols() + public List getOutputVariables() { return outputs; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Memo.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Memo.java index f28b3218c4a27..6c109d9a1edc0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Memo.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Memo.java @@ -114,11 +114,11 @@ public PlanNode replace(int group, PlanNode node, String reason) { PlanNode old = getGroup(group).membership; - checkArgument(new HashSet<>(old.getOutputSymbols()).equals(new HashSet<>(node.getOutputSymbols())), + checkArgument(new HashSet<>(old.getOutputVariables()).equals(new HashSet<>(node.getOutputVariables())), "%s: transformed expression doesn't produce same outputs: %s vs %s", reason, - old.getOutputSymbols(), - node.getOutputSymbols()); + old.getOutputVariables(), + node.getOutputVariables()); if (node instanceof GroupReference) { node = getNode(((GroupReference) node).getGroupId()); @@ -215,7 +215,7 @@ private PlanNode insertChildrenAndRewrite(PlanNode node) .map(child -> new GroupReference( idAllocator.getNextId(), insertRecursive(child), - child.getOutputSymbols())) + child.getOutputVariables())) .collect(Collectors.toList())); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java index a0f45152d792b..c24f51a4ffc8e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java @@ -17,7 +17,8 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -25,6 +26,7 @@ import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -34,7 +36,7 @@ import static com.facebook.presto.SystemSessionProperties.getTaskConcurrency; import static com.facebook.presto.SystemSessionProperties.isEnableIntermediateAggregations; import static com.facebook.presto.matching.Pattern.empty; -import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.extractUnique; +import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.extractUniqueVariables; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.FINAL; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.INTERMEDIATE; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.PARTIAL; @@ -100,8 +102,9 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont Lookup lookup = context.getLookup(); PlanNodeIdAllocator idAllocator = context.getIdAllocator(); Session session = context.getSession(); + TypeProvider types = context.getSymbolAllocator().getTypes(); - Optional rewrittenSource = recurseToPartial(lookup.resolve(aggregation.getSource()), lookup, idAllocator); + Optional rewrittenSource = recurseToPartial(lookup.resolve(aggregation.getSource()), lookup, idAllocator, types); if (!rewrittenSource.isPresent()) { return Result.empty(); @@ -114,12 +117,12 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont source = new AggregationNode( idAllocator.getNextId(), source, - inputsAsOutputs(aggregation.getAggregations()), + inputsAsOutputs(aggregation.getAggregations(), types), aggregation.getGroupingSets(), - aggregation.getPreGroupedSymbols(), + aggregation.getPreGroupedVariables(), INTERMEDIATE, - aggregation.getHashSymbol(), - aggregation.getGroupIdSymbol()); + aggregation.getHashVariable(), + aggregation.getGroupIdVariable()); source = gatheringExchange(idAllocator.getNextId(), LOCAL, source); } @@ -129,10 +132,10 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont /** * Recurse through a series of preceding ExchangeNodes and ProjectNodes to find the preceding PARTIAL aggregation */ - private Optional recurseToPartial(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator) + private Optional recurseToPartial(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, TypeProvider types) { if (node instanceof AggregationNode && ((AggregationNode) node).getStep() == PARTIAL) { - return Optional.of(addGatheringIntermediate((AggregationNode) node, idAllocator)); + return Optional.of(addGatheringIntermediate((AggregationNode) node, idAllocator, types)); } if (!(node instanceof ExchangeNode) && !(node instanceof ProjectNode)) { @@ -141,7 +144,7 @@ private Optional recurseToPartial(PlanNode node, Lookup lookup, PlanNo ImmutableList.Builder builder = ImmutableList.builder(); for (PlanNode source : node.getSources()) { - Optional planNode = recurseToPartial(lookup.resolve(source), lookup, idAllocator); + Optional planNode = recurseToPartial(lookup.resolve(source), lookup, idAllocator, types); if (!planNode.isPresent()) { return Optional.empty(); } @@ -150,7 +153,7 @@ private Optional recurseToPartial(PlanNode node, Lookup lookup, PlanNo return Optional.of(node.replaceChildren(builder.build())); } - private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeIdAllocator idAllocator) + private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeIdAllocator idAllocator, TypeProvider types) { verify(aggregation.getGroupingKeys().isEmpty(), "Should be an un-grouped aggregation"); ExchangeNode gatheringExchange = gatheringExchange(idAllocator.getNextId(), LOCAL, aggregation); @@ -159,10 +162,10 @@ private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeI gatheringExchange, outputsAsInputs(aggregation.getAggregations()), aggregation.getGroupingSets(), - aggregation.getPreGroupedSymbols(), + aggregation.getPreGroupedVariables(), INTERMEDIATE, - aggregation.getHashSymbol(), - aggregation.getGroupIdSymbol()); + aggregation.getHashVariable(), + aggregation.getGroupIdVariable()); } /** @@ -172,18 +175,18 @@ private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeI * 'a' := sum('b') => 'a' := sum('a') * 'a' := count(*) => 'a' := count('a') */ - private static Map outputsAsInputs(Map assignments) + private static Map outputsAsInputs(Map assignments) { - ImmutableMap.Builder builder = ImmutableMap.builder(); - for (Map.Entry entry : assignments.entrySet()) { - Symbol output = entry.getKey(); + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (Map.Entry entry : assignments.entrySet()) { + VariableReferenceExpression output = entry.getKey(); Aggregation aggregation = entry.getValue(); checkState(!aggregation.getOrderBy().isPresent(), "Intermediate aggregation does not support ORDER BY"); builder.put( output, new Aggregation( aggregation.getFunctionHandle(), - ImmutableList.of(output.toSymbolReference()), + ImmutableList.of(new SymbolReference(output.getName())), Optional.empty(), Optional.empty(), false, @@ -200,16 +203,16 @@ private static Map outputsAsInputs(Map * Example: * 'a' := sum('b') => 'b' := sum('b') */ - private static Map inputsAsOutputs(Map assignments) + private static Map inputsAsOutputs(Map assignments, TypeProvider types) { - ImmutableMap.Builder builder = ImmutableMap.builder(); - for (Map.Entry entry : assignments.entrySet()) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (Map.Entry entry : assignments.entrySet()) { // Should only have one input symbol Aggregation aggregation = entry.getValue(); checkArgument( aggregation.getArguments().size() == 1 && !aggregation.getOrderBy().isPresent() && !aggregation.getFilter().isPresent(), "Aggregation should only have one argument and should have no order by or filter to be able to rewritten to intermediate form"); - Symbol input = getOnlyElement(extractUnique(entry.getValue())); + VariableReferenceExpression input = getOnlyElement(extractUniqueVariables(entry.getValue(), types)); builder.put(input, entry.getValue()); } return builder.build(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DesugarLambdaExpression.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DesugarLambdaExpression.java index f8a6f89bde5d4..c7140d563e5c3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DesugarLambdaExpression.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DesugarLambdaExpression.java @@ -40,6 +40,6 @@ public Set> rules() private static Expression rewrite(Expression expression, Rule.Context context) { - return LambdaCaptureDesugaringRewriter.rewrite(expression, context.getSymbolAllocator().getTypes(), context.getSymbolAllocator()); + return LambdaCaptureDesugaringRewriter.rewrite(expression, context.getSymbolAllocator()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java index f263cccdd2b74..c29403fa0485d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java @@ -82,7 +82,7 @@ public static boolean isBelowMaxBroadcastSize(JoinNode joinNode, Context context PlanNode buildSide = joinNode.getRight(); PlanNodeStatsEstimate buildSideStatsEstimate = context.getStatsProvider().getStats(buildSide); - double buildSideSizeInBytes = buildSideStatsEstimate.getOutputSizeInBytes(buildSide.getOutputSymbols(), context.getSymbolAllocator().getTypes()); + double buildSideSizeInBytes = buildSideStatsEstimate.getOutputSizeInBytes(buildSide.getOutputVariables()); return buildSideSizeInBytes <= joinMaxBroadcastTableSize.get().toBytes(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineSemiJoinDistributionType.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineSemiJoinDistributionType.java index 2c65060ef4083..aeaa8e0e70e70 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineSemiJoinDistributionType.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineSemiJoinDistributionType.java @@ -123,7 +123,7 @@ private boolean canReplicate(SemiJoinNode node, Context context) PlanNode buildSide = node.getFilteringSource(); PlanNodeStatsEstimate buildSideStatsEstimate = context.getStatsProvider().getStats(buildSide); - double buildSideSizeInBytes = buildSideStatsEstimate.getOutputSizeInBytes(buildSide.getOutputSymbols(), context.getSymbolAllocator().getTypes()); + double buildSideSizeInBytes = buildSideStatsEstimate.getOutputSizeInBytes(buildSide.getOutputVariables()); return buildSideSizeInBytes <= joinMaxBroadcastTableSize.get().toBytes(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java index 440b1674e6590..a707a76e76d27 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -18,8 +18,8 @@ import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.joins.JoinGraph; import com.facebook.presto.sql.planner.plan.Assignments; @@ -83,7 +83,7 @@ public Result apply(JoinNode node, Captures captures, Context context) return Result.empty(); } - PlanNode replacement = buildJoinTree(node.getOutputSymbols(), joinGraph, joinOrder, context.getIdAllocator()); + PlanNode replacement = buildJoinTree(node.getOutputVariables(), joinGraph, joinOrder, context.getIdAllocator()); return Result.ofPlanNode(replacement); } @@ -147,9 +147,9 @@ public static List getJoinOrder(JoinGraph graph) .collect(toImmutableList()); } - public static PlanNode buildJoinTree(List expectedOutputSymbols, JoinGraph graph, List joinOrder, PlanNodeIdAllocator idAllocator) + public static PlanNode buildJoinTree(List expectedOutputVariables, JoinGraph graph, List joinOrder, PlanNodeIdAllocator idAllocator) { - requireNonNull(expectedOutputSymbols, "expectedOutputSymbols is null"); + requireNonNull(expectedOutputVariables, "expectedOutputVariables is null"); requireNonNull(idAllocator, "idAllocator is null"); requireNonNull(graph, "graph is null"); joinOrder = ImmutableList.copyOf(requireNonNull(joinOrder, "joinOrder is null")); @@ -169,8 +169,8 @@ public static PlanNode buildJoinTree(List expectedOutputSymbols, JoinGra PlanNode targetNode = edge.getTargetNode(); if (alreadyJoinedNodes.contains(targetNode.getId())) { criteria.add(new JoinNode.EquiJoinClause( - edge.getTargetSymbol(), - edge.getSourceSymbol())); + edge.getTargetVariable(), + edge.getSourceVariable())); } } @@ -180,9 +180,9 @@ public static PlanNode buildJoinTree(List expectedOutputSymbols, JoinGra result, rightNode, criteria.build(), - ImmutableList.builder() - .addAll(result.getOutputSymbols()) - .addAll(rightNode.getOutputSymbols()) + ImmutableList.builder() + .addAll(result.getOutputVariables()) + .addAll(rightNode.getOutputVariables()) .build(), Optional.empty(), Optional.empty(), @@ -208,6 +208,6 @@ public static PlanNode buildJoinTree(List expectedOutputSymbols, JoinGra // If needed, introduce a projection to constrain the outputs to what was originally expected // Some nodes are sensitive to what's produced (e.g., DistinctLimit node) - return restrictOutputs(idAllocator, result, ImmutableSet.copyOf(expectedOutputSymbols)).orElse(result); + return restrictOutputs(idAllocator, result, ImmutableSet.copyOf(expectedOutputVariables)).orElse(result); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java index 71b73d7c510cc..c448c57dbe124 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java @@ -38,6 +38,6 @@ public Pattern getPattern() @Override public Result apply(LimitNode limit, Captures captures, Context context) { - return Result.ofPlanNode(new ValuesNode(limit.getId(), limit.getOutputSymbols(), ImmutableList.of())); + return Result.ofPlanNode(new ValuesNode(limit.getId(), limit.getOutputVariables(), ImmutableList.of())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java index d6d94680dcde2..e11271e944091 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java @@ -41,6 +41,6 @@ public Pattern getPattern() @Override public Result apply(SampleNode sample, Captures captures, Context context) { - return Result.ofPlanNode(new ValuesNode(sample.getId(), sample.getOutputSymbols(), ImmutableList.of())); + return Result.ofPlanNode(new ValuesNode(sample.getId(), sample.getOutputVariables(), ImmutableList.of())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java index e31368279bb77..8a07301a10fbe 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java @@ -16,7 +16,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.relation.RowExpression; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; @@ -151,8 +151,8 @@ public Pattern getPattern() public Result apply(AggregationNode aggregationNode, Captures captures, Context context) { boolean anyRewritten = false; - ImmutableMap.Builder aggregations = ImmutableMap.builder(); - for (Map.Entry entry : aggregationNode.getAggregations().entrySet()) { + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + for (Map.Entry entry : aggregationNode.getAggregations().entrySet()) { Aggregation aggregation = entry.getValue(); Aggregation rewritten = new Aggregation( aggregation.getFunctionHandle(), @@ -173,10 +173,10 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context aggregationNode.getSource(), aggregations.build(), aggregationNode.getGroupingSets(), - aggregationNode.getPreGroupedSymbols(), + aggregationNode.getPreGroupedVariables(), aggregationNode.getStep(), - aggregationNode.getHashSymbol(), - aggregationNode.getGroupIdSymbol())); + aggregationNode.getHashVariable(), + aggregationNode.getGroupIdVariable())); } return Result.empty(); } @@ -242,10 +242,10 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) joinNode.getLeft(), joinNode.getRight(), joinNode.getCriteria(), - joinNode.getOutputSymbols(), + joinNode.getOutputVariables(), filter.map(OriginalExpressionUtils::castToRowExpression), - joinNode.getLeftHashSymbol(), - joinNode.getRightHashSymbol(), + joinNode.getLeftHashVariable(), + joinNode.getRightHashVariable(), joinNode.getDistributionType())); } return Result.empty(); @@ -292,7 +292,7 @@ public Result apply(ValuesNode valuesNode, Captures captures, Context context) rows.add(newRow.build()); } if (anyRewritten) { - return Result.ofPlanNode(new ValuesNode(valuesNode.getId(), valuesNode.getOutputSymbols(), rows.build())); + return Result.ofPlanNode(new ValuesNode(valuesNode.getId(), valuesNode.getOutputVariables(), rows.build())); } return Result.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java index bb66f175fb5e3..a274cd6beda30 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -32,6 +32,7 @@ import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeSignature; @@ -40,7 +41,6 @@ import com.facebook.presto.split.SplitSource; import com.facebook.presto.split.SplitSource.SplitBatch; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.Rule.Context; import com.facebook.presto.sql.planner.iterative.Rule.Result; @@ -87,7 +87,7 @@ import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static com.facebook.presto.sql.planner.ExpressionNodeInliner.replaceExpression; -import static com.facebook.presto.sql.planner.SymbolsExtractor.extractUnique; +import static com.facebook.presto.sql.planner.SymbolsExtractor.extractUniqueVariable; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; import static com.facebook.presto.sql.planner.plan.Patterns.filter; @@ -219,7 +219,7 @@ public Result apply(FilterNode node, Captures captures, Context context) Expression filter = OriginalExpressionUtils.castToExpression(node.getPredicate()); List spatialFunctions = extractSupportedSpatialFunctions(filter); for (FunctionCall spatialFunction : spatialFunctions) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputVariables(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, sqlParser); if (!result.isEmpty()) { return result; } @@ -227,7 +227,7 @@ public Result apply(FilterNode node, Captures captures, Context context) List spatialComparisons = extractSupportedSpatialComparisons(filter); for (ComparisonExpression spatialComparison : spatialComparisons) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialComparison, metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputVariables(), spatialComparison, metadata, splitManager, pageSourceManager, sqlParser); if (!result.isEmpty()) { return result; } @@ -274,7 +274,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) Expression filter = OriginalExpressionUtils.castToExpression(joinNode.getFilter().get()); List spatialFunctions = extractSupportedSpatialFunctions(filter); for (FunctionCall spatialFunction : spatialFunctions) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputVariables(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, sqlParser); if (!result.isEmpty()) { return result; } @@ -282,7 +282,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) List spatialComparisons = extractSupportedSpatialComparisons(filter); for (ComparisonExpression spatialComparison : spatialComparisons) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialComparison, metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputVariables(), spatialComparison, metadata, splitManager, pageSourceManager, sqlParser); if (!result.isEmpty()) { return result; } @@ -297,7 +297,7 @@ private static Result tryCreateSpatialJoin( JoinNode joinNode, Expression filter, PlanNodeId nodeId, - List outputSymbols, + List outputVariables, ComparisonExpression spatialComparison, Metadata metadata, SplitManager splitManager, @@ -307,19 +307,19 @@ private static Result tryCreateSpatialJoin( PlanNode leftNode = joinNode.getLeft(); PlanNode rightNode = joinNode.getRight(); - List leftSymbols = leftNode.getOutputSymbols(); - List rightSymbols = rightNode.getOutputSymbols(); + List leftVariables = leftNode.getOutputVariables(); + List rightVariables = rightNode.getOutputVariables(); Expression radius; - Optional newRadiusSymbol; + Optional newRadiusVariable; ComparisonExpression newComparison; if (spatialComparison.getOperator() == LESS_THAN || spatialComparison.getOperator() == LESS_THAN_OR_EQUAL) { // ST_Distance(a, b) <= r radius = spatialComparison.getRight(); - Set radiusSymbols = extractUnique(radius); - if (radiusSymbols.isEmpty() || (rightSymbols.containsAll(radiusSymbols) && containsNone(leftSymbols, radiusSymbols))) { - newRadiusSymbol = newRadiusSymbol(context, radius); - newComparison = new ComparisonExpression(spatialComparison.getOperator(), spatialComparison.getLeft(), toExpression(newRadiusSymbol, radius)); + Set radiusVariables = extractUniqueVariable(radius, context.getSymbolAllocator().getTypes()); + if (radiusVariables.isEmpty() || (rightVariables.containsAll(radiusVariables) && containsNone(leftVariables, radiusVariables))) { + newRadiusVariable = newRadiusVariable(context, radius); + newComparison = new ComparisonExpression(spatialComparison.getOperator(), spatialComparison.getLeft(), toExpression(newRadiusVariable, radius)); } else { return Result.empty(); @@ -328,10 +328,10 @@ private static Result tryCreateSpatialJoin( else { // r >= ST_Distance(a, b) radius = spatialComparison.getLeft(); - Set radiusSymbols = extractUnique(radius); - if (radiusSymbols.isEmpty() || (rightSymbols.containsAll(radiusSymbols) && containsNone(leftSymbols, radiusSymbols))) { - newRadiusSymbol = newRadiusSymbol(context, radius); - newComparison = new ComparisonExpression(spatialComparison.getOperator().flip(), spatialComparison.getRight(), toExpression(newRadiusSymbol, radius)); + Set radiusVariables = extractUniqueVariable(radius, context.getSymbolAllocator().getTypes()); + if (radiusVariables.isEmpty() || (rightVariables.containsAll(radiusVariables) && containsNone(leftVariables, radiusVariables))) { + newRadiusVariable = newRadiusVariable(context, radius); + newComparison = new ComparisonExpression(spatialComparison.getOperator().flip(), spatialComparison.getRight(), toExpression(newRadiusVariable, radius)); } else { return Result.empty(); @@ -339,7 +339,7 @@ private static Result tryCreateSpatialJoin( } Expression newFilter = replaceExpression(filter, ImmutableMap.of(spatialComparison, newComparison)); - PlanNode newRightNode = newRadiusSymbol.map(symbol -> addProjection(context, rightNode, symbol, radius)).orElse(rightNode); + PlanNode newRightNode = newRadiusVariable.map(variable -> addProjection(context, rightNode, variable, radius)).orElse(rightNode); JoinNode newJoinNode = new JoinNode( joinNode.getId(), @@ -347,13 +347,13 @@ private static Result tryCreateSpatialJoin( leftNode, newRightNode, joinNode.getCriteria(), - joinNode.getOutputSymbols(), + joinNode.getOutputVariables(), Optional.of(castToRowExpression(newFilter)), - joinNode.getLeftHashSymbol(), - joinNode.getRightHashSymbol(), + joinNode.getLeftHashVariable(), + joinNode.getRightHashVariable(), joinNode.getDistributionType()); - return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputSymbols, (FunctionCall) newComparison.getLeft(), Optional.of(newComparison.getRight()), metadata, splitManager, pageSourceManager, sqlParser); + return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputVariables, (FunctionCall) newComparison.getLeft(), Optional.of(newComparison.getRight()), metadata, splitManager, pageSourceManager, sqlParser); } private static Result tryCreateSpatialJoin( @@ -361,7 +361,7 @@ private static Result tryCreateSpatialJoin( JoinNode joinNode, Expression filter, PlanNodeId nodeId, - List outputSymbols, + List outputVariables, FunctionCall spatialFunction, Optional radius, Metadata metadata, @@ -385,15 +385,15 @@ private static Result tryCreateSpatialJoin( return Result.empty(); } - Set firstSymbols = extractUnique(firstArgument); - Set secondSymbols = extractUnique(secondArgument); + Set firstVariables = extractUniqueVariable(firstArgument, context.getSymbolAllocator().getTypes()); + Set secondVariables = extractUniqueVariable(secondArgument, context.getSymbolAllocator().getTypes()); - if (firstSymbols.isEmpty() || secondSymbols.isEmpty()) { + if (firstVariables.isEmpty() || secondVariables.isEmpty()) { return Result.empty(); } - Optional newFirstSymbol = newGeometrySymbol(context, firstArgument, metadata); - Optional newSecondSymbol = newGeometrySymbol(context, secondArgument, metadata); + Optional newFirstVariable = newGeometryVariable(context, firstArgument, metadata); + Optional newSecondVariable = newGeometryVariable(context, secondArgument, metadata); PlanNode leftNode = joinNode.getLeft(); PlanNode rightNode = joinNode.getRight(); @@ -402,35 +402,35 @@ private static Result tryCreateSpatialJoin( PlanNode newRightNode; // Check if the order of arguments of the spatial function matches the order of join sides - int alignment = checkAlignment(joinNode, firstSymbols, secondSymbols); + int alignment = checkAlignment(joinNode, firstVariables, secondVariables); if (alignment > 0) { - newLeftNode = newFirstSymbol.map(symbol -> addProjection(context, leftNode, symbol, firstArgument)).orElse(leftNode); - newRightNode = newSecondSymbol.map(symbol -> addProjection(context, rightNode, symbol, secondArgument)).orElse(rightNode); + newLeftNode = newFirstVariable.map(variable -> addProjection(context, leftNode, variable, firstArgument)).orElse(leftNode); + newRightNode = newSecondVariable.map(variable -> addProjection(context, rightNode, variable, secondArgument)).orElse(rightNode); } else if (alignment < 0) { - newLeftNode = newSecondSymbol.map(symbol -> addProjection(context, leftNode, symbol, secondArgument)).orElse(leftNode); - newRightNode = newFirstSymbol.map(symbol -> addProjection(context, rightNode, symbol, firstArgument)).orElse(rightNode); + newLeftNode = newSecondVariable.map(variable -> addProjection(context, leftNode, variable, secondArgument)).orElse(leftNode); + newRightNode = newFirstVariable.map(variable -> addProjection(context, rightNode, variable, firstArgument)).orElse(rightNode); } else { return Result.empty(); } - Expression newFirstArgument = toExpression(newFirstSymbol, firstArgument); - Expression newSecondArgument = toExpression(newSecondSymbol, secondArgument); + Expression newFirstArgument = toExpression(newFirstVariable, firstArgument); + Expression newSecondArgument = toExpression(newSecondVariable, secondArgument); - Optional leftPartitionSymbol = Optional.empty(); - Optional rightPartitionSymbol = Optional.empty(); + Optional leftPartitionVariable = Optional.empty(); + Optional rightPartitionVariable = Optional.empty(); if (kdbTree.isPresent()) { - leftPartitionSymbol = Optional.of(context.getSymbolAllocator().newSymbol("pid", INTEGER)); - rightPartitionSymbol = Optional.of(context.getSymbolAllocator().newSymbol("pid", INTEGER)); + leftPartitionVariable = Optional.of(context.getSymbolAllocator().newVariable("pid", INTEGER)); + rightPartitionVariable = Optional.of(context.getSymbolAllocator().newVariable("pid", INTEGER)); if (alignment > 0) { - newLeftNode = addPartitioningNodes(context, newLeftNode, leftPartitionSymbol.get(), kdbTree.get(), newFirstArgument, Optional.empty()); - newRightNode = addPartitioningNodes(context, newRightNode, rightPartitionSymbol.get(), kdbTree.get(), newSecondArgument, radius); + newLeftNode = addPartitioningNodes(context, newLeftNode, leftPartitionVariable.get(), kdbTree.get(), newFirstArgument, Optional.empty()); + newRightNode = addPartitioningNodes(context, newRightNode, rightPartitionVariable.get(), kdbTree.get(), newSecondArgument, radius); } else { - newLeftNode = addPartitioningNodes(context, newLeftNode, leftPartitionSymbol.get(), kdbTree.get(), newSecondArgument, Optional.empty()); - newRightNode = addPartitioningNodes(context, newRightNode, rightPartitionSymbol.get(), kdbTree.get(), newFirstArgument, radius); + newLeftNode = addPartitioningNodes(context, newLeftNode, leftPartitionVariable.get(), kdbTree.get(), newSecondArgument, Optional.empty()); + newRightNode = addPartitioningNodes(context, newRightNode, rightPartitionVariable.get(), kdbTree.get(), newFirstArgument, radius); } } @@ -442,10 +442,10 @@ else if (alignment < 0) { SpatialJoinNode.Type.fromJoinNodeType(joinNode.getType()), newLeftNode, newRightNode, - outputSymbols, + outputVariables, castToRowExpression(newFilter), - leftPartitionSymbol, - rightPartitionSymbol, + leftPartitionVariable, + rightPartitionVariable, kdbTree.map(KdbTreeUtils::toJson))); } @@ -538,67 +538,67 @@ private static QualifiedObjectName toQualifiedObjectName(String name, String cat throw new PrestoException(INVALID_SPATIAL_PARTITIONING, format("Invalid name: %s", name)); } - private static int checkAlignment(JoinNode joinNode, Set maybeLeftSymbols, Set maybeRightSymbols) + private static int checkAlignment(JoinNode joinNode, Set maybeLeftVariables, Set maybeRightVariables) { - List leftSymbols = joinNode.getLeft().getOutputSymbols(); - List rightSymbols = joinNode.getRight().getOutputSymbols(); + List leftVariables = joinNode.getLeft().getOutputVariables(); + List rightVariables = joinNode.getRight().getOutputVariables(); - if (leftSymbols.containsAll(maybeLeftSymbols) - && containsNone(leftSymbols, maybeRightSymbols) - && rightSymbols.containsAll(maybeRightSymbols) - && containsNone(rightSymbols, maybeLeftSymbols)) { + if (leftVariables.containsAll(maybeLeftVariables) + && containsNone(leftVariables, maybeRightVariables) + && rightVariables.containsAll(maybeRightVariables) + && containsNone(rightVariables, maybeLeftVariables)) { return 1; } - if (leftSymbols.containsAll(maybeRightSymbols) - && containsNone(leftSymbols, maybeLeftSymbols) - && rightSymbols.containsAll(maybeLeftSymbols) - && containsNone(rightSymbols, maybeRightSymbols)) { + if (leftVariables.containsAll(maybeRightVariables) + && containsNone(leftVariables, maybeLeftVariables) + && rightVariables.containsAll(maybeLeftVariables) + && containsNone(rightVariables, maybeRightVariables)) { return -1; } return 0; } - private static Expression toExpression(Optional optionalSymbol, Expression defaultExpression) + private static Expression toExpression(Optional optionalVariable, Expression defaultExpression) { - return optionalSymbol.map(symbol -> (Expression) symbol.toSymbolReference()).orElse(defaultExpression); + return optionalVariable.map(variable -> (Expression) new SymbolReference(variable.getName())).orElse(defaultExpression); } - private static Optional newGeometrySymbol(Context context, Expression expression, Metadata metadata) + private static Optional newGeometryVariable(Context context, Expression expression, Metadata metadata) { if (expression instanceof SymbolReference) { return Optional.empty(); } - return Optional.of(context.getSymbolAllocator().newSymbol(expression, metadata.getType(GEOMETRY_TYPE_SIGNATURE))); + return Optional.of(context.getSymbolAllocator().newVariable(expression, metadata.getType(GEOMETRY_TYPE_SIGNATURE))); } - private static Optional newRadiusSymbol(Context context, Expression expression) + private static Optional newRadiusVariable(Context context, Expression expression) { if (expression instanceof SymbolReference) { return Optional.empty(); } - return Optional.of(context.getSymbolAllocator().newSymbol(expression, DOUBLE)); + return Optional.of(context.getSymbolAllocator().newVariable(expression, DOUBLE)); } - private static PlanNode addProjection(Context context, PlanNode node, Symbol symbol, Expression expression) + private static PlanNode addProjection(Context context, PlanNode node, VariableReferenceExpression variable, Expression expression) { Assignments.Builder projections = Assignments.builder(); - for (Symbol outputSymbol : node.getOutputSymbols()) { - projections.putIdentity(outputSymbol); + for (VariableReferenceExpression outputVariable : node.getOutputVariables()) { + projections.putIdentity(outputVariable); } - projections.put(symbol, expression); + projections.put(variable, expression); return new ProjectNode(context.getIdAllocator().getNextId(), node, projections.build()); } - private static PlanNode addPartitioningNodes(Context context, PlanNode node, Symbol partitionSymbol, KdbTree kdbTree, Expression geometry, Optional radius) + private static PlanNode addPartitioningNodes(Context context, PlanNode node, VariableReferenceExpression partitionVariable, KdbTree kdbTree, Expression geometry, Optional radius) { Assignments.Builder projections = Assignments.builder(); - for (Symbol outputSymbol : node.getOutputSymbols()) { - projections.putIdentity(outputSymbol); + for (VariableReferenceExpression outputVariable : node.getOutputVariables()) { + projections.putIdentity(outputVariable); } ImmutableList.Builder partitioningArguments = ImmutableList.builder() @@ -607,18 +607,18 @@ private static PlanNode addPartitioningNodes(Context context, PlanNode node, Sym radius.map(partitioningArguments::add); FunctionCall partitioningFunction = new FunctionCall(QualifiedName.of("spatial_partitions"), partitioningArguments.build()); - Symbol partitionsSymbol = context.getSymbolAllocator().newSymbol(partitioningFunction, new ArrayType(INTEGER)); - projections.put(partitionsSymbol, partitioningFunction); + VariableReferenceExpression partitionsVariable = context.getSymbolAllocator().newVariable(partitioningFunction, new ArrayType(INTEGER)); + projections.put(partitionsVariable, partitioningFunction); return new UnnestNode( context.getIdAllocator().getNextId(), new ProjectNode(context.getIdAllocator().getNextId(), node, projections.build()), - node.getOutputSymbols(), - ImmutableMap.of(partitionsSymbol, ImmutableList.of(partitionSymbol)), + node.getOutputVariables(), + ImmutableMap.of(partitionsVariable, ImmutableList.of(partitionVariable)), Optional.empty()); } - private static boolean containsNone(Collection values, Collection testValues) + private static boolean containsNone(Collection values, Collection testValues) { return values.stream().noneMatch(ImmutableSet.copyOf(testValues)::contains); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java index b4aa55db2c23e..2542da904de3f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java @@ -17,9 +17,8 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.matching.PropertyPattern; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.OrderingScheme; -import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -40,6 +39,7 @@ import java.util.stream.Stream; import static com.facebook.presto.matching.Capture.newCapture; +import static com.facebook.presto.sql.planner.SymbolsExtractor.extractUniqueVariable; import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; import static com.facebook.presto.sql.planner.iterative.rule.Util.transpose; import static com.facebook.presto.sql.planner.optimizations.WindowNodeUtil.dependsOn; @@ -103,7 +103,7 @@ public Result apply(WindowNode parent, Captures captures, Context context) .map(captures::get) .collect(toImmutableList()); - return pullWindowNodeAboveProjects(captures.get(childCapture), projects) + return pullWindowNodeAboveProjects(captures.get(childCapture), projects, context) .flatMap(newChild -> manipulateAdjacentWindowNodes(parent, newChild, context)) .map(Result::ofPlanNode) .orElse(Result.empty()); @@ -120,7 +120,8 @@ public Result apply(WindowNode parent, Captures captures, Context context) */ protected static Optional pullWindowNodeAboveProjects( WindowNode target, - List projects) + List projects, + Context context) { if (projects.isEmpty()) { return Optional.of(target); @@ -128,17 +129,17 @@ protected static Optional pullWindowNodeAboveProjects( PlanNode targetChild = target.getSource(); - Set targetInputs = ImmutableSet.copyOf(targetChild.getOutputSymbols()); - Set targetOutputs = ImmutableSet.copyOf(target.getOutputSymbols()); + Set targetInputs = ImmutableSet.copyOf(targetChild.getOutputVariables()); + Set targetOutputs = ImmutableSet.copyOf(target.getOutputVariables()); PlanNode newTargetChild = targetChild; for (ProjectNode project : projects) { - Set newTargetChildOutputs = ImmutableSet.copyOf(newTargetChild.getOutputSymbols()); + Set newTargetChildOutputs = ImmutableSet.copyOf(newTargetChild.getOutputVariables()); // The only kind of use of the output of the target that we can safely ignore is a simple identity propagation. // The target node, when hoisted above the projections, will provide the symbols directly. - Map assignmentsWithoutTargetOutputIdentities = Maps.filterKeys( + Map assignmentsWithoutTargetOutputIdentities = Maps.filterKeys( project.getAssignments().getMap(), output -> !(project.getAssignments().isIdentity(output) && targetOutputs.contains(output))); @@ -152,7 +153,7 @@ protected static Optional pullWindowNodeAboveProjects( .putIdentities(targetInputs) .build(); - if (!newTargetChildOutputs.containsAll(SymbolsExtractor.extractUnique(newAssignments.getExpressions()))) { + if (!newTargetChildOutputs.containsAll(extractUniqueVariable(newAssignments.getExpressions(), context.getSymbolAllocator().getTypes()))) { // Projection uses an output of the target -- can't move the target above this projection. return Optional.empty(); } @@ -161,8 +162,8 @@ protected static Optional pullWindowNodeAboveProjects( } WindowNode newTarget = (WindowNode) target.replaceChildren(ImmutableList.of(newTargetChild)); - Set newTargetOutputs = ImmutableSet.copyOf(newTarget.getOutputSymbols()); - if (!newTargetOutputs.containsAll(projects.get(projects.size() - 1).getOutputSymbols())) { + Set newTargetOutputs = ImmutableSet.copyOf(newTarget.getOutputVariables()); + if (!newTargetOutputs.containsAll(projects.get(projects.size() - 1).getOutputVariables())) { // The new target node is hiding some of the projections, which makes this rewrite incorrect. return Optional.empty(); } @@ -181,11 +182,11 @@ public MergeAdjacentWindowsOverProjects(int numProjects) @Override protected Optional manipulateAdjacentWindowNodes(WindowNode parent, WindowNode child, Context context) { - if (!child.getSpecification().equals(parent.getSpecification()) || dependsOn(parent, child)) { + if (!child.getSpecification().equals(parent.getSpecification()) || dependsOn(parent, child, context.getSymbolAllocator().getTypes())) { return Optional.empty(); } - ImmutableMap.Builder functionsBuilder = ImmutableMap.builder(); + ImmutableMap.Builder functionsBuilder = ImmutableMap.builder(); functionsBuilder.putAll(parent.getWindowFunctions()); functionsBuilder.putAll(child.getWindowFunctions()); @@ -194,12 +195,12 @@ protected Optional manipulateAdjacentWindowNodes(WindowNode parent, Wi child.getSource(), parent.getSpecification(), functionsBuilder.build(), - parent.getHashSymbol(), + parent.getHashVariable(), parent.getPrePartitionedInputs(), parent.getPreSortedOrderPrefix()); return Optional.of( - restrictOutputs(context.getIdAllocator(), mergedWindowNode, ImmutableSet.copyOf(parent.getOutputSymbols())) + restrictOutputs(context.getIdAllocator(), mergedWindowNode, ImmutableSet.copyOf(parent.getOutputVariables())) .orElse(mergedWindowNode)); } } @@ -215,10 +216,10 @@ public SwapAdjacentWindowsBySpecifications(int numProjects) @Override protected Optional manipulateAdjacentWindowNodes(WindowNode parent, WindowNode child, Context context) { - if ((compare(parent, child) < 0) && (!dependsOn(parent, child))) { + if ((compare(parent, child) < 0) && (!dependsOn(parent, child, context.getSymbolAllocator().getTypes()))) { PlanNode transposedWindows = transpose(parent, child); return Optional.of( - restrictOutputs(context.getIdAllocator(), transposedWindows, ImmutableSet.copyOf(parent.getOutputSymbols())) + restrictOutputs(context.getIdAllocator(), transposedWindows, ImmutableSet.copyOf(parent.getOutputVariables())) .orElse(transposedWindows)); } else { @@ -244,14 +245,14 @@ private static int compare(WindowNode o1, WindowNode o2) private static int comparePartitionBy(WindowNode o1, WindowNode o2) { - Iterator iterator1 = o1.getPartitionBy().iterator(); - Iterator iterator2 = o2.getPartitionBy().iterator(); + Iterator iterator1 = o1.getPartitionBy().iterator(); + Iterator iterator2 = o2.getPartitionBy().iterator(); while (iterator1.hasNext() && iterator2.hasNext()) { - Symbol symbol1 = iterator1.next(); - Symbol symbol2 = iterator2.next(); + VariableReferenceExpression variable1 = iterator1.next(); + VariableReferenceExpression variable2 = iterator2.next(); - int partitionByComparison = symbol1.compareTo(symbol2); + int partitionByComparison = variable1.compareTo(variable2); if (partitionByComparison != 0) { return partitionByComparison; } @@ -280,19 +281,19 @@ else if (!o1.getOrderingScheme().isPresent() && o2.getOrderingScheme().isPresent OrderingScheme o1OrderingScheme = o1.getOrderingScheme().get(); OrderingScheme o2OrderingScheme = o2.getOrderingScheme().get(); - Iterator iterator1 = o1OrderingScheme.getOrderBy().iterator(); - Iterator iterator2 = o2OrderingScheme.getOrderBy().iterator(); + Iterator iterator1 = o1OrderingScheme.getOrderBy().iterator(); + Iterator iterator2 = o2OrderingScheme.getOrderBy().iterator(); while (iterator1.hasNext() && iterator2.hasNext()) { - Symbol symbol1 = iterator1.next(); - Symbol symbol2 = iterator2.next(); + VariableReferenceExpression variable1 = iterator1.next(); + VariableReferenceExpression variable2 = iterator2.next(); - int orderByComparison = symbol1.compareTo(symbol2); + int orderByComparison = variable1.compareTo(variable2); if (orderByComparison != 0) { return orderByComparison; } else { - int sortOrderComparison = o1OrderingScheme.getOrdering(symbol1).compareTo(o2OrderingScheme.getOrdering(symbol2)); + int sortOrderComparison = o1OrderingScheme.getOrdering(variable1).compareTo(o2OrderingScheme.getOrdering(variable2)); if (sortOrderComparison != 0) { return sortOrderComparison; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java index 109948828aa22..07d200e4419c9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java @@ -15,7 +15,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; @@ -23,6 +23,7 @@ import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -81,24 +82,24 @@ public Pattern getPattern() public Result apply(AggregationNode aggregation, Captures captures, Context context) { Assignments.Builder newAssignments = Assignments.builder(); - ImmutableMap.Builder aggregations = ImmutableMap.builder(); + ImmutableMap.Builder aggregations = ImmutableMap.builder(); ImmutableList.Builder maskSymbols = ImmutableList.builder(); boolean aggregateWithoutFilterPresent = false; - for (Map.Entry entry : aggregation.getAggregations().entrySet()) { - Symbol output = entry.getKey(); + for (Map.Entry entry : aggregation.getAggregations().entrySet()) { + VariableReferenceExpression output = entry.getKey(); // strip the filters - Optional mask = entry.getValue().getMask(); + Optional mask = entry.getValue().getMask(); if (entry.getValue().getFilter().isPresent()) { Expression filter = entry.getValue().getFilter().get(); - Symbol symbol = context.getSymbolAllocator().newSymbol(filter, BOOLEAN); + VariableReferenceExpression variable = context.getSymbolAllocator().newVariable(filter, BOOLEAN); verify(!mask.isPresent(), "Expected aggregation without mask symbols, see Rule pattern"); - newAssignments.put(symbol, filter); - mask = Optional.of(symbol); + newAssignments.put(variable, filter); + mask = Optional.of(variable); - maskSymbols.add(symbol.toSymbolReference()); + maskSymbols.add(new SymbolReference(variable.getName())); } else { aggregateWithoutFilterPresent = true; @@ -119,7 +120,7 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont } // identity projection for all existing inputs - newAssignments.putIdentities(aggregation.getSource().getOutputSymbols()); + newAssignments.putIdentities(aggregation.getSource().getOutputVariables()); return Result.ofPlanNode( new AggregationNode( @@ -135,7 +136,7 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont aggregation.getGroupingSets(), ImmutableList.of(), aggregation.getStep(), - aggregation.getHashSymbol(), - aggregation.getGroupIdSymbol())); + aggregation.getHashVariable(), + aggregation.getGroupIdVariable())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java index 56b1239d2121b..c2692540936cc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java @@ -16,8 +16,10 @@ import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -64,14 +66,14 @@ public Result apply(ProjectNode parent, Captures captures, Context context) { ProjectNode child = captures.get(CHILD); - Sets.SetView targets = extractInliningTargets(parent, child); + Sets.SetView targets = extractInliningTargets(parent, child, context); if (targets.isEmpty()) { return Result.empty(); } // inline the expressions Assignments assignments = child.getAssignments().filter(targets::contains); - Map parentAssignments = parent.getAssignments() + Map parentAssignments = parent.getAssignments() .entrySet().stream() .collect(Collectors.toMap( Map.Entry::getKey, @@ -81,20 +83,20 @@ public Result apply(ProjectNode parent, Captures captures, Context context) // to place in the child projection. // If all assignments end up becoming identity assignments, they'll get pruned by // other rules - Set inputs = child.getAssignments() + Set inputs = child.getAssignments() .entrySet().stream() .filter(entry -> targets.contains(entry.getKey())) .map(Map.Entry::getValue) - .flatMap(entry -> SymbolsExtractor.extractAll(entry).stream()) + .flatMap(entry -> SymbolsExtractor.extractAllVariable(entry, context.getSymbolAllocator().getTypes()).stream()) .collect(toSet()); Assignments.Builder childAssignments = Assignments.builder(); - for (Map.Entry assignment : child.getAssignments().entrySet()) { + for (Map.Entry assignment : child.getAssignments().entrySet()) { if (!targets.contains(assignment.getKey())) { childAssignments.put(assignment); } } - for (Symbol input : inputs) { + for (VariableReferenceExpression input : inputs) { childAssignments.putIdentity(input); } @@ -122,7 +124,7 @@ private Expression inlineReferences(Expression expression, Assignments assignmen return inlineSymbols(mapping, expression); } - private Sets.SetView extractInliningTargets(ProjectNode parent, ProjectNode child) + private Sets.SetView extractInliningTargets(ProjectNode parent, ProjectNode child, Context context) { // candidates for inlining are // 1. references to simple constants @@ -132,27 +134,27 @@ private Sets.SetView extractInliningTargets(ProjectNode parent, ProjectN // c. are not identity projections // which come from the child, as opposed to an enclosing scope. - Set childOutputSet = ImmutableSet.copyOf(child.getOutputSymbols()); + Set childOutputSet = ImmutableSet.copyOf(child.getOutputVariables()); - Map dependencies = parent.getAssignments() + Map dependencies = parent.getAssignments() .getExpressions().stream() - .flatMap(expression -> SymbolsExtractor.extractAll(expression).stream()) + .flatMap(expression -> SymbolsExtractor.extractAllVariable(expression, context.getSymbolAllocator().getTypes()).stream()) .filter(childOutputSet::contains) .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); // find references to simple constants - Set constants = dependencies.keySet().stream() + Set constants = dependencies.keySet().stream() .filter(input -> child.getAssignments().get(input) instanceof Literal) .collect(toSet()); // exclude any complex inputs to TRY expressions. Inlining them would potentially // change the semantics of those expressions - Set tryArguments = parent.getAssignments() + Set tryArguments = parent.getAssignments() .getExpressions().stream() - .flatMap(expression -> extractTryArguments(expression).stream()) + .flatMap(expression -> extractTryArguments(expression, context.getSymbolAllocator().getTypes()).stream()) .collect(toSet()); - Set singletons = dependencies.entrySet().stream() + Set singletons = dependencies.entrySet().stream() .filter(entry -> entry.getValue() == 1) // reference appears just once across all expressions in parent project node .filter(entry -> !tryArguments.contains(entry.getKey())) // they are not inputs to TRY. Otherwise, inlining might change semantics .filter(entry -> !child.getAssignments().isIdentity(entry.getKey())) // skip identities, otherwise, this rule will keep firing forever @@ -162,12 +164,12 @@ private Sets.SetView extractInliningTargets(ProjectNode parent, ProjectN return Sets.union(singletons, constants); } - private Set extractTryArguments(Expression expression) + private Set extractTryArguments(Expression expression, TypeProvider types) { return AstUtils.preOrder(expression) .filter(TryExpression.class::isInstance) .map(TryExpression.class::cast) - .flatMap(tryExpression -> SymbolsExtractor.extractAll(tryExpression).stream()) + .flatMap(tryExpression -> SymbolsExtractor.extractAllVariable(tryExpression, types).stream()) .collect(toSet()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/LambdaCaptureDesugaringRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/LambdaCaptureDesugaringRewriter.java index 69a3dc1bbc99b..f0b1ddf611f04 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/LambdaCaptureDesugaringRewriter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/LambdaCaptureDesugaringRewriter.java @@ -14,9 +14,9 @@ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.tree.BindExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; @@ -33,15 +33,16 @@ import java.util.Set; import java.util.function.Function; -import static com.facebook.presto.sql.planner.ExpressionSymbolInliner.inlineSymbols; +import static com.facebook.presto.sql.planner.ExpressionVariableInliner.inlineVariables; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Objects.requireNonNull; public class LambdaCaptureDesugaringRewriter { - public static Expression rewrite(Expression expression, TypeProvider symbolTypes, SymbolAllocator symbolAllocator) + public static Expression rewrite(Expression expression, SymbolAllocator symbolAllocator) { - return ExpressionTreeRewriter.rewriteWith(new Visitor(symbolTypes, symbolAllocator), expression, new Context()); + return ExpressionTreeRewriter.rewriteWith(new Visitor(symbolAllocator), expression, new Context()); } private LambdaCaptureDesugaringRewriter() {} @@ -49,12 +50,10 @@ private LambdaCaptureDesugaringRewriter() {} private static class Visitor extends ExpressionRewriter { - private final TypeProvider symbolTypes; private final SymbolAllocator symbolAllocator; - public Visitor(TypeProvider symbolTypes, SymbolAllocator symbolAllocator) + public Visitor(SymbolAllocator symbolAllocator) { - this.symbolTypes = requireNonNull(symbolTypes, "symbolTypes is null"); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); } @@ -62,8 +61,8 @@ public Visitor(TypeProvider symbolTypes, SymbolAllocator symbolAllocator) public Expression rewriteLambdaExpression(LambdaExpression node, Context context, ExpressionTreeRewriter treeRewriter) { // Use linked hash set to guarantee deterministic iteration order - LinkedHashSet referencedSymbols = new LinkedHashSet<>(); - Expression rewrittenBody = treeRewriter.rewrite(node.getBody(), context.withReferencedSymbols(referencedSymbols)); + LinkedHashSet referencedVariables = new LinkedHashSet<>(); + Expression rewrittenBody = treeRewriter.rewrite(node.getBody(), context.withReferencedVariables(referencedVariables)); List lambdaArguments = node.getArguments().stream() .map(LambdaArgumentDeclaration::getName) @@ -71,42 +70,40 @@ public Expression rewriteLambdaExpression(LambdaExpression node, Context context .map(Symbol::new) .collect(toImmutableList()); - // referenced symbols - lambda arguments = capture symbols - // referencedSymbols no longer contains what its name suggests after this line - referencedSymbols.removeAll(lambdaArguments); - Set captureSymbols = referencedSymbols; + // referenced variables - lambda arguments = capture variables + Set captureVariables = referencedVariables.stream().filter(variable -> !lambdaArguments.contains(new Symbol(variable.getName()))).collect(toImmutableSet()); // x -> f(x, captureSymbol) will be rewritten into // "$internal$bind"(captureSymbol, (extraSymbol, x) -> f(x, extraSymbol)) - ImmutableMap.Builder captureSymbolToExtraSymbol = ImmutableMap.builder(); + ImmutableMap.Builder captureVariableToExtraVariable = ImmutableMap.builder(); ImmutableList.Builder newLambdaArguments = ImmutableList.builder(); - for (Symbol captureSymbol : captureSymbols) { - Symbol extraSymbol = symbolAllocator.newSymbol(captureSymbol.getName(), symbolTypes.get(captureSymbol)); - captureSymbolToExtraSymbol.put(captureSymbol, extraSymbol); - newLambdaArguments.add(new LambdaArgumentDeclaration(new Identifier(extraSymbol.getName()))); + for (VariableReferenceExpression captureVariable : captureVariables) { + VariableReferenceExpression extraVariable = symbolAllocator.newVariable(captureVariable); + captureVariableToExtraVariable.put(captureVariable, extraVariable); + newLambdaArguments.add(new LambdaArgumentDeclaration(new Identifier(extraVariable.getName()))); } newLambdaArguments.addAll(node.getArguments()); - ImmutableMap symbolsMap = captureSymbolToExtraSymbol.build(); - Function symbolMapping = symbol -> symbolsMap.getOrDefault(symbol, symbol).toSymbolReference(); - Expression rewrittenExpression = new LambdaExpression(newLambdaArguments.build(), inlineSymbols(symbolMapping, rewrittenBody)); + ImmutableMap symbolsMap = captureVariableToExtraVariable.build(); + Function variableMapping = variable -> new SymbolReference(symbolsMap.getOrDefault(variable, variable).getName()); + Expression rewrittenExpression = new LambdaExpression(newLambdaArguments.build(), inlineVariables(variableMapping, rewrittenBody, symbolAllocator.getTypes())); - if (captureSymbols.size() != 0) { - List capturedValues = captureSymbols.stream() + if (captureVariables.size() != 0) { + List capturedValues = captureVariables.stream() .map(symbol -> new SymbolReference(symbol.getName())) .collect(toImmutableList()); rewrittenExpression = new BindExpression(capturedValues, rewrittenExpression); } - context.getReferencedSymbols().addAll(captureSymbols); + context.getReferencedVariables().addAll(captureVariables); return rewrittenExpression; } @Override public Expression rewriteSymbolReference(SymbolReference node, Context context, ExpressionTreeRewriter treeRewriter) { - context.getReferencedSymbols().add(new Symbol(node.getName())); + context.getReferencedVariables().add(symbolAllocator.toVariableReference(Symbol.from(node))); return null; } } @@ -114,26 +111,26 @@ public Expression rewriteSymbolReference(SymbolReference node, Context context, private static class Context { // Use linked hash set to guarantee deterministic iteration order - final LinkedHashSet referencedSymbols; + final LinkedHashSet referencedVariables; public Context() { this(new LinkedHashSet<>()); } - private Context(LinkedHashSet referencedSymbols) + private Context(LinkedHashSet referencedVariables) { - this.referencedSymbols = referencedSymbols; + this.referencedVariables = referencedVariables; } - public LinkedHashSet getReferencedSymbols() + public LinkedHashSet getReferencedVariables() { - return referencedSymbols; + return referencedVariables; } - public Context withReferencedSymbols(LinkedHashSet symbols) + public Context withReferencedVariables(LinkedHashSet variables) { - return new Context(symbols); + return new Context(variables); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java index 19749c606a0d9..cf4be49a038e8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java @@ -41,8 +41,8 @@ public class MergeLimitWithDistinct private static boolean isDistinct(AggregationNode node) { return node.getAggregations().isEmpty() && - node.getOutputSymbols().size() == node.getGroupingKeys().size() && - node.getOutputSymbols().containsAll(node.getGroupingKeys()); + node.getOutputVariables().size() == node.getGroupingKeys().size() && + node.getOutputVariables().containsAll(node.getGroupingKeys()); } @Override @@ -63,6 +63,6 @@ public Result apply(LimitNode parent, Captures captures, Context context) parent.getCount(), false, child.getGroupingKeys(), - child.getHashSymbol())); + child.getHashVariable())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java index f0bd454fc00a9..09ddc41ba469b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java @@ -16,6 +16,7 @@ import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -112,34 +113,35 @@ public Result apply(AggregationNode parent, Captures captures, Context context) } // the distinct marker for the given set of input columns - Map, Symbol> markers = new HashMap<>(); + Map, VariableReferenceExpression> markers = new HashMap<>(); - Map newAggregations = new HashMap<>(); + Map newAggregations = new HashMap<>(); PlanNode subPlan = parent.getSource(); - for (Map.Entry entry : parent.getAggregations().entrySet()) { + for (Map.Entry entry : parent.getAggregations().entrySet()) { Aggregation aggregation = entry.getValue(); if (aggregation.isDistinct() && !aggregation.getFilter().isPresent() && !aggregation.getMask().isPresent()) { - Set inputs = aggregation.getArguments().stream() + Set inputs = aggregation.getArguments().stream() .map(Symbol::from) + .map(context.getSymbolAllocator()::toVariableReference) .collect(toSet()); - Symbol marker = markers.get(inputs); + VariableReferenceExpression marker = markers.get(inputs); if (marker == null) { - marker = context.getSymbolAllocator().newSymbol(Iterables.getLast(inputs).getName(), BOOLEAN, "distinct"); + marker = context.getSymbolAllocator().newVariable(Iterables.getLast(inputs).getName(), BOOLEAN, "distinct"); markers.put(inputs, marker); - ImmutableSet.Builder distinctSymbols = ImmutableSet.builder() + ImmutableSet.Builder distinctVariables = ImmutableSet.builder() .addAll(parent.getGroupingKeys()) .addAll(inputs); - parent.getGroupIdSymbol().ifPresent(distinctSymbols::add); + parent.getGroupIdVariable().ifPresent(distinctVariables::add); subPlan = new MarkDistinctNode( context.getIdAllocator().getNextId(), subPlan, marker, - ImmutableList.copyOf(distinctSymbols.build()), + ImmutableList.copyOf(distinctVariables.build()), Optional.empty()); } @@ -166,7 +168,7 @@ public Result apply(AggregationNode parent, Captures captures, Context context) parent.getGroupingSets(), ImmutableList.of(), parent.getStep(), - parent.getHashSymbol(), - parent.getGroupIdSymbol())); + parent.getHashVariable(), + parent.getGroupIdVariable())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java index 874d29d7bc3ff..1d70824706c0f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java @@ -77,6 +77,7 @@ import static com.facebook.presto.sql.relational.RowExpressionNodeInliner.replaceExpression; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.google.common.collect.ImmutableBiMap.toImmutableBiMap; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Sets.intersection; import static java.util.Collections.emptyList; @@ -233,13 +234,13 @@ public Result apply(TableScanNode tableScanNode, Captures captures, Context cont if (metadata.isPushdownFilterSupported(session, tableHandle)) { PushdownFilterResult pushdownFilterResult = metadata.pushdownFilter(session, tableHandle, TRUE); if (pushdownFilterResult.getLayout().getPredicate().isNone()) { - return Result.ofPlanNode(new ValuesNode(context.getIdAllocator().getNextId(), tableScanNode.getOutputSymbols(), ImmutableList.of())); + return Result.ofPlanNode(new ValuesNode(context.getIdAllocator().getNextId(), tableScanNode.getOutputVariables(), ImmutableList.of())); } return Result.ofPlanNode(new TableScanNode( tableScanNode.getId(), pushdownFilterResult.getLayout().getNewTableHandle(), - tableScanNode.getOutputSymbols(), + tableScanNode.getOutputVariables(), tableScanNode.getAssignments(), pushdownFilterResult.getLayout().getPredicate(), TupleDomain.all())); @@ -249,18 +250,18 @@ public Result apply(TableScanNode tableScanNode, Captures captures, Context cont session, tableHandle, Constraint.alwaysTrue(), - Optional.of(tableScanNode.getOutputSymbols().stream() - .map(tableScanNode.getAssignments()::get) + Optional.of(tableScanNode.getOutputVariables().stream() + .map(variable -> tableScanNode.getAssignments().get(variable)) .collect(toImmutableSet()))); if (layout.getLayout().getPredicate().isNone()) { - return Result.ofPlanNode(new ValuesNode(context.getIdAllocator().getNextId(), tableScanNode.getOutputSymbols(), ImmutableList.of())); + return Result.ofPlanNode(new ValuesNode(context.getIdAllocator().getNextId(), tableScanNode.getOutputVariables(), ImmutableList.of())); } return Result.ofPlanNode(new TableScanNode( tableScanNode.getId(), layout.getLayout().getNewTableHandle(), - tableScanNode.getOutputSymbols(), + tableScanNode.getOutputVariables(), tableScanNode.getAssignments(), layout.getLayout().getPredicate(), TupleDomain.all())); @@ -291,21 +292,21 @@ public static PlanNode pushPredicateIntoTableScan( BiMap symbolToColumnMapping = node.getAssignments().entrySet().stream() .collect(toImmutableBiMap( - entry -> new VariableReferenceExpression(entry.getKey().getName(), types.get(entry.getKey())), - entry -> new VariableReferenceExpression(getColumnName(session, metadata, node.getTable(), entry.getValue()), types.get(entry.getKey())))); + entry -> new VariableReferenceExpression(entry.getKey().getName(), types.get(new Symbol(entry.getKey().getName()))), + entry -> new VariableReferenceExpression(getColumnName(session, metadata, node.getTable(), entry.getValue()), types.get(new Symbol(entry.getKey().getName()))))); RowExpression translatedPredicate = replaceExpression(SqlToRowExpressionTranslator.translate(predicate, predicateTypes, ImmutableMap.of(), metadata.getFunctionManager(), metadata.getTypeManager(), session, false), symbolToColumnMapping); PushdownFilterResult pushdownFilterResult = metadata.pushdownFilter(session, node.getTable(), translatedPredicate); TableLayout layout = pushdownFilterResult.getLayout(); if (layout.getPredicate().isNone()) { - return new ValuesNode(idAllocator.getNextId(), node.getOutputSymbols(), ImmutableList.of()); + return new ValuesNode(idAllocator.getNextId(), node.getOutputVariables(), ImmutableList.of()); } TableScanNode tableScan = new TableScanNode( node.getId(), layout.getNewTableHandle(), - node.getOutputSymbols(), + node.getOutputVariables(), node.getAssignments(), layout.getPredicate(), TupleDomain.all()); @@ -327,10 +328,10 @@ public static PlanNode pushPredicateIntoTableScan( types); TupleDomain newDomain = decomposedPredicate.getTupleDomain() - .transform(node.getAssignments()::get) + .transform(symbol -> node.getAssignments().entrySet().stream().collect(toImmutableMap(entry -> new Symbol(entry.getKey().getName()), Map.Entry::getValue)).get(symbol)) .intersect(node.getEnforcedConstraint()); - Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); + Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); Constraint constraint; if (pruneWithPredicateExpression) { @@ -344,7 +345,7 @@ public static PlanNode pushPredicateIntoTableScan( deterministicPredicate, // Simplify the tuple domain to avoid creating an expression with too many nodes, // which would be expensive to evaluate in the call to isCandidate below. - domainTranslator.toPredicate(newDomain.simplify().transform(assignments::get)))); + domainTranslator.toPredicate(newDomain.simplify().transform(column -> assignments.containsKey(column) ? new Symbol(assignments.get(column).getName()) : null)))); constraint = new Constraint<>(newDomain, evaluator::isCandidate); } else { @@ -353,7 +354,7 @@ public static PlanNode pushPredicateIntoTableScan( constraint = new Constraint<>(newDomain); } if (constraint.getSummary().isNone()) { - return new ValuesNode(idAllocator.getNextId(), node.getOutputSymbols(), ImmutableList.of()); + return new ValuesNode(idAllocator.getNextId(), node.getOutputVariables(), ImmutableList.of()); } // Layouts will be returned in order of the connector's preference @@ -361,18 +362,18 @@ public static PlanNode pushPredicateIntoTableScan( session, node.getTable(), constraint, - Optional.of(node.getOutputSymbols().stream() - .map(node.getAssignments()::get) + Optional.of(node.getOutputVariables().stream() + .map(variable -> node.getAssignments().get(variable)) .collect(toImmutableSet()))); if (layout.getLayout().getPredicate().isNone()) { - return new ValuesNode(idAllocator.getNextId(), node.getOutputSymbols(), ImmutableList.of()); + return new ValuesNode(idAllocator.getNextId(), node.getOutputVariables(), ImmutableList.of()); } TableScanNode tableScan = new TableScanNode( node.getId(), layout.getLayout().getNewTableHandle(), - node.getOutputSymbols(), + node.getOutputVariables(), node.getAssignments(), layout.getLayout().getPredicate(), computeEnforced(newDomain, layout.getUnenforcedConstraint())); @@ -386,7 +387,7 @@ public static PlanNode pushPredicateIntoTableScan( // and non-TupleDomain-expressible expressions should be retained. Changing the order can lead // to failures of previously successful queries. Expression resultingPredicate = combineConjuncts( - domainTranslator.toPredicate(layout.getUnenforcedConstraint().transform(assignments::get)), + domainTranslator.toPredicate(layout.getUnenforcedConstraint().transform(column -> new Symbol(assignments.get(column).getName()))), filterNonDeterministicConjuncts(predicate), decomposedPredicate.getRemainingExpression()); @@ -403,11 +404,11 @@ private static String getColumnName(Session session, Metadata metadata, TableHan private static class LayoutConstraintEvaluator { - private final Map assignments; + private final Map assignments; private final ExpressionInterpreter evaluator; private final Set arguments; - public LayoutConstraintEvaluator(Metadata metadata, SqlParser parser, Session session, TypeProvider types, Map assignments, Expression expression) + public LayoutConstraintEvaluator(Metadata metadata, SqlParser parser, Session session, TypeProvider types, Map assignments, Expression expression) { this.assignments = assignments; @@ -415,7 +416,7 @@ public LayoutConstraintEvaluator(Metadata metadata, SqlParser parser, Session se evaluator = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); arguments = SymbolsExtractor.extractUnique(expression).stream() - .map(assignments::get) + .map(symbol -> assignments.entrySet().stream().collect(toImmutableMap(entry -> new Symbol(entry.getKey().getName()), Map.Entry::getValue)).get(symbol)) .collect(toImmutableSet()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ProjectOffPushDownRule.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ProjectOffPushDownRule.java index e2c9333d2c276..604a2921ef1fa 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ProjectOffPushDownRule.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ProjectOffPushDownRule.java @@ -17,7 +17,8 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -60,12 +61,12 @@ public Result apply(ProjectNode parent, Captures captures, Context context) { N targetNode = captures.get(targetCapture); - return pruneInputs(targetNode.getOutputSymbols(), parent.getAssignments().getExpressions()) - .flatMap(prunedOutputs -> this.pushDownProjectOff(context.getIdAllocator(), targetNode, prunedOutputs)) + return pruneInputs(targetNode.getOutputVariables(), parent.getAssignments().getExpressions(), context.getSymbolAllocator().getTypes()) + .flatMap(prunedOutputs -> this.pushDownProjectOff(context.getIdAllocator(), context.getSymbolAllocator(), targetNode, prunedOutputs)) .map(newChild -> parent.replaceChildren(ImmutableList.of(newChild))) .map(Result::ofPlanNode) .orElse(Result.empty()); } - protected abstract Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, N targetNode, Set referencedOutputs); + protected abstract Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, N targetNode, Set referencedOutputs); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneAggregationColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneAggregationColumns.java index e8d0b0fb7219e..5a963d40dc1a1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneAggregationColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneAggregationColumns.java @@ -14,7 +14,8 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.google.common.collect.Maps; @@ -36,12 +37,11 @@ public PruneAggregationColumns() @Override protected Optional pushDownProjectOff( PlanNodeIdAllocator idAllocator, + SymbolAllocator symbolAllocator, AggregationNode aggregationNode, - Set referencedOutputs) + Set referencedOutputs) { - Map prunedAggregations = Maps.filterKeys( - aggregationNode.getAggregations(), - referencedOutputs::contains); + Map prunedAggregations = Maps.filterKeys(aggregationNode.getAggregations(), referencedOutputs::contains); if (prunedAggregations.size() == aggregationNode.getAggregations().size()) { return Optional.empty(); @@ -54,9 +54,9 @@ protected Optional pushDownProjectOff( aggregationNode.getSource(), prunedAggregations, aggregationNode.getGroupingSets(), - aggregationNode.getPreGroupedSymbols(), + aggregationNode.getPreGroupedVariables(), aggregationNode.getStep(), - aggregationNode.getHashSymbol(), - aggregationNode.getGroupIdSymbol())); + aggregationNode.getHashVariable(), + aggregationNode.getGroupIdVariable())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneAggregationSourceColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneAggregationSourceColumns.java index 4d1e797d83af1..d5fb77315a4c1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneAggregationSourceColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneAggregationSourceColumns.java @@ -15,7 +15,8 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -42,11 +43,11 @@ public Pattern getPattern() @Override public Result apply(AggregationNode aggregationNode, Captures captures, Context context) { - Set requiredInputs = Streams.concat( + Set requiredInputs = Streams.concat( aggregationNode.getGroupingKeys().stream(), - aggregationNode.getHashSymbol().map(Stream::of).orElse(Stream.empty()), + aggregationNode.getHashVariable().map(Stream::of).orElse(Stream.empty()), aggregationNode.getAggregations().values().stream() - .flatMap(PruneAggregationSourceColumns::getAggregationInputs)) + .flatMap(aggregation -> getAggregationInputs(aggregation, context.getSymbolAllocator().getTypes()))) .collect(toImmutableSet()); return restrictChildOutputs(context.getIdAllocator(), aggregationNode, requiredInputs) @@ -54,10 +55,10 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context .orElse(Result.empty()); } - private static Stream getAggregationInputs(AggregationNode.Aggregation aggregation) + private static Stream getAggregationInputs(AggregationNode.Aggregation aggregation, TypeProvider types) { return Streams.concat( - AggregationNodeUtils.extractUnique(aggregation).stream(), + AggregationNodeUtils.extractUniqueVariables(aggregation, types).stream(), aggregation.getMask().map(Stream::of).orElse(Stream.empty())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCountAggregationOverScalar.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCountAggregationOverScalar.java index 6b9f341de536b..5cb1adee51a35 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCountAggregationOverScalar.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCountAggregationOverScalar.java @@ -17,7 +17,7 @@ import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.spi.function.StandardFunctionResolution; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ValuesNode; @@ -57,11 +57,11 @@ public Pattern getPattern() @Override public Result apply(AggregationNode parent, Captures captures, Context context) { - if (!parent.hasDefaultOutput() || parent.getOutputSymbols().size() != 1) { + if (!parent.hasDefaultOutput() || parent.getOutputVariables().size() != 1) { return Result.empty(); } - Map assignments = parent.getAggregations(); - for (Map.Entry entry : assignments.entrySet()) { + Map assignments = parent.getAggregations(); + for (Map.Entry entry : assignments.entrySet()) { AggregationNode.Aggregation aggregation = entry.getValue(); requireNonNull(aggregation, "aggregation is null"); if (!functionResolution.isCountFunction(aggregation.getFunctionHandle()) || !aggregation.getArguments().isEmpty()) { @@ -71,7 +71,7 @@ public Result apply(AggregationNode parent, Captures captures, Context context) if (!assignments.isEmpty() && isScalar(parent.getSource(), context.getLookup())) { return Result.ofPlanNode(new ValuesNode( parent.getId(), - parent.getOutputSymbols(), + parent.getOutputVariables(), ImmutableList.of(ImmutableList.of(constant(1L, BIGINT))))); } return Result.empty(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCrossJoinColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCrossJoinColumns.java index 80aea1baa8d67..084910c7b7e9f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCrossJoinColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCrossJoinColumns.java @@ -14,7 +14,8 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.google.common.collect.ImmutableList; @@ -37,7 +38,7 @@ public PruneCrossJoinColumns() } @Override - protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, JoinNode joinNode, Set referencedOutputs) + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, JoinNode joinNode, Set referencedOutputs) { Optional newLeft = restrictOutputs(idAllocator, joinNode.getLeft(), referencedOutputs); Optional newRight = restrictOutputs(idAllocator, joinNode.getRight(), referencedOutputs); @@ -46,9 +47,9 @@ protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, return Optional.empty(); } - ImmutableList.Builder outputSymbolBuilder = ImmutableList.builder(); - outputSymbolBuilder.addAll(newLeft.orElse(joinNode.getLeft()).getOutputSymbols()); - outputSymbolBuilder.addAll(newRight.orElse(joinNode.getRight()).getOutputSymbols()); + ImmutableList.Builder outputVariableBuilder = ImmutableList.builder(); + outputVariableBuilder.addAll(newLeft.orElse(joinNode.getLeft()).getOutputVariables()); + outputVariableBuilder.addAll(newRight.orElse(joinNode.getRight()).getOutputVariables()); return Optional.of(new JoinNode( idAllocator.getNextId(), @@ -56,10 +57,10 @@ protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, newLeft.orElse(joinNode.getLeft()), newRight.orElse(joinNode.getRight()), joinNode.getCriteria(), - outputSymbolBuilder.build(), + outputVariableBuilder.build(), joinNode.getFilter(), - joinNode.getLeftHashSymbol(), - joinNode.getRightHashSymbol(), + joinNode.getLeftHashVariable(), + joinNode.getRightHashVariable(), joinNode.getDistributionType())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneFilterColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneFilterColumns.java index 71b154658e6c2..3c0e42684e39f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneFilterColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneFilterColumns.java @@ -14,7 +14,8 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -37,11 +38,11 @@ public PruneFilterColumns() } @Override - protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, FilterNode filterNode, Set referencedOutputs) + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, FilterNode filterNode, Set referencedOutputs) { - Set prunedFilterInputs = Streams.concat( + Set prunedFilterInputs = Streams.concat( referencedOutputs.stream(), - SymbolsExtractor.extractUnique(castToExpression(filterNode.getPredicate())).stream()) + SymbolsExtractor.extractUniqueVariable(castToExpression(filterNode.getPredicate()), symbolAllocator.getTypes()).stream()) .collect(toImmutableSet()); return restrictChildOutputs(idAllocator, filterNode, prunedFilterInputs); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneIndexSourceColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneIndexSourceColumns.java index 14ada09c73192..393043d30d6d9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneIndexSourceColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneIndexSourceColumns.java @@ -16,7 +16,8 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.predicate.TupleDomain; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.google.common.collect.Maps; @@ -39,19 +40,19 @@ public PruneIndexSourceColumns() } @Override - protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, IndexSourceNode indexSourceNode, Set referencedOutputs) + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, IndexSourceNode indexSourceNode, Set referencedOutputs) { - Set prunedLookupSymbols = indexSourceNode.getLookupSymbols().stream() + Set prunedLookupSymbols = indexSourceNode.getLookupVariables().stream() .filter(referencedOutputs::contains) .collect(toImmutableSet()); - Map prunedAssignments = Maps.filterEntries( + Map prunedAssignments = Maps.filterEntries( indexSourceNode.getAssignments(), entry -> referencedOutputs.contains(entry.getKey()) || tupleDomainReferencesColumnHandle(indexSourceNode.getCurrentConstraint(), entry.getValue())); - List prunedOutputList = - indexSourceNode.getOutputSymbols().stream() + List prunedOutputList = + indexSourceNode.getOutputVariables().stream() .filter(referencedOutputs::contains) .collect(toImmutableList()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinChildrenColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinChildrenColumns.java index 3d861990c6b75..0525fa63afd33 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinChildrenColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinChildrenColumns.java @@ -15,8 +15,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.relational.OriginalExpressionUtils; @@ -24,12 +23,13 @@ import java.util.Set; +import static com.facebook.presto.sql.planner.SymbolsExtractor.extractUniqueVariable; import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictChildOutputs; import static com.facebook.presto.sql.planner.plan.Patterns.join; import static com.google.common.base.Predicates.not; /** - * Non-Cross joins support output symbol selection, so make any project-off of child columns explicit in project nodes. + * Non-Cross joins support output variable selection, so make any project-off of child columns explicit in project nodes. */ public class PruneJoinChildrenColumns implements Rule @@ -46,31 +46,31 @@ public Pattern getPattern() @Override public Result apply(JoinNode joinNode, Captures captures, Context context) { - Set globallyUsableInputs = ImmutableSet.builder() - .addAll(joinNode.getOutputSymbols()) + Set globallyUsableInputs = ImmutableSet.builder() + .addAll(joinNode.getOutputVariables()) .addAll( joinNode.getFilter() .map(OriginalExpressionUtils::castToExpression) - .map(SymbolsExtractor::extractUnique) + .map(expression -> extractUniqueVariable(expression, context.getSymbolAllocator().getTypes())) .orElse(ImmutableSet.of())) .build(); - Set leftUsableInputs = ImmutableSet.builder() + Set leftUsableInputs = ImmutableSet.builder() .addAll(globallyUsableInputs) .addAll( joinNode.getCriteria().stream() .map(JoinNode.EquiJoinClause::getLeft) .iterator()) - .addAll(joinNode.getLeftHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of())) + .addAll(joinNode.getLeftHashVariable().map(ImmutableSet::of).orElse(ImmutableSet.of())) .build(); - Set rightUsableInputs = ImmutableSet.builder() + Set rightUsableInputs = ImmutableSet.builder() .addAll(globallyUsableInputs) .addAll( joinNode.getCriteria().stream() .map(JoinNode.EquiJoinClause::getRight) .iterator()) - .addAll(joinNode.getRightHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of())) + .addAll(joinNode.getRightHashVariable().map(ImmutableSet::of).orElse(ImmutableSet.of())) .build(); return restrictChildOutputs(context.getIdAllocator(), joinNode, leftUsableInputs, rightUsableInputs) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinColumns.java index 92e9efcc790e7..7b50430b25dbe 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinColumns.java @@ -14,7 +14,8 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -37,7 +38,7 @@ public PruneJoinColumns() } @Override - protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, JoinNode joinNode, Set referencedOutputs) + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, JoinNode joinNode, Set referencedOutputs) { return Optional.of( new JoinNode( @@ -46,10 +47,10 @@ protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, joinNode.getLeft(), joinNode.getRight(), joinNode.getCriteria(), - filteredCopy(joinNode.getOutputSymbols(), referencedOutputs::contains), + filteredCopy(joinNode.getOutputVariables(), referencedOutputs::contains), joinNode.getFilter(), - joinNode.getLeftHashSymbol(), - joinNode.getRightHashSymbol(), + joinNode.getLeftHashVariable(), + joinNode.getRightHashVariable(), joinNode.getDistributionType())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneLimitColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneLimitColumns.java index 6ec5e6fe26194..b450f7e4f540c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneLimitColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneLimitColumns.java @@ -14,7 +14,8 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -33,7 +34,7 @@ public PruneLimitColumns() } @Override - protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, LimitNode limitNode, Set referencedOutputs) + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, LimitNode limitNode, Set referencedOutputs) { return restrictChildOutputs(idAllocator, limitNode, referencedOutputs); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMarkDistinctColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMarkDistinctColumns.java index 3d78d1364b2d5..16357244faefa 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMarkDistinctColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMarkDistinctColumns.java @@ -14,7 +14,8 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.google.common.collect.Streams; @@ -36,17 +37,17 @@ public PruneMarkDistinctColumns() } @Override - protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, MarkDistinctNode markDistinctNode, Set referencedOutputs) + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, MarkDistinctNode markDistinctNode, Set referencedOutputs) { - if (!referencedOutputs.contains(markDistinctNode.getMarkerSymbol())) { + if (!referencedOutputs.contains(markDistinctNode.getMarkerVariable())) { return Optional.of(markDistinctNode.getSource()); } - Set requiredInputs = Streams.concat( + Set requiredInputs = Streams.concat( referencedOutputs.stream() - .filter(symbol -> !symbol.equals(markDistinctNode.getMarkerSymbol())), - markDistinctNode.getDistinctSymbols().stream(), - markDistinctNode.getHashSymbol().map(Stream::of).orElse(Stream.empty())) + .filter(variable -> !variable.equals(markDistinctNode.getMarkerVariable())), + markDistinctNode.getDistinctVariables().stream(), + markDistinctNode.getHashVariable().map(Stream::of).orElse(Stream.empty())) .collect(toImmutableSet()); return restrictChildOutputs(idAllocator, markDistinctNode, requiredInputs); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOrderByInAggregation.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOrderByInAggregation.java index de224e5d0f1b5..f5c785656b0c4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOrderByInAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOrderByInAggregation.java @@ -16,7 +16,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.FunctionManager; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.google.common.collect.ImmutableMap; @@ -53,8 +53,8 @@ public Result apply(AggregationNode node, Captures captures, Context context) } boolean anyRewritten = false; - ImmutableMap.Builder aggregations = ImmutableMap.builder(); - for (Map.Entry entry : node.getAggregations().entrySet()) { + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + for (Map.Entry entry : node.getAggregations().entrySet()) { Aggregation aggregation = entry.getValue(); if (!aggregation.getOrderBy().isPresent()) { aggregations.put(entry); @@ -84,9 +84,9 @@ else if (functionManager.getAggregateFunctionImplementation(aggregation.getFunct node.getSource(), aggregations.build(), node.getGroupingSets(), - node.getPreGroupedSymbols(), + node.getPreGroupedVariables(), node.getStep(), - node.getHashSymbol(), - node.getGroupIdSymbol())); + node.getHashVariable(), + node.getGroupIdVariable())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOutputColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOutputColumns.java index ea91f74176b43..1950a706282b6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOutputColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOutputColumns.java @@ -39,7 +39,7 @@ public Result apply(OutputNode outputNode, Captures captures, Context context) return restrictChildOutputs( context.getIdAllocator(), outputNode, - ImmutableSet.copyOf(outputNode.getOutputSymbols())) + ImmutableSet.copyOf(outputNode.getOutputVariables())) .map(Result::ofPlanNode) .orElse(Result.empty()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneProjectColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneProjectColumns.java index d5080b3bf3822..5d87140d4ac80 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneProjectColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneProjectColumns.java @@ -14,7 +14,8 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -34,8 +35,9 @@ public PruneProjectColumns() @Override protected Optional pushDownProjectOff( PlanNodeIdAllocator idAllocator, + SymbolAllocator symbolAllocator, ProjectNode childProjectNode, - Set referencedOutputs) + Set referencedOutputs) { return Optional.of( new ProjectNode( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinColumns.java index b34cf7a78a343..8164d4ab29ff8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinColumns.java @@ -14,7 +14,8 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; import com.google.common.collect.ImmutableList; @@ -37,17 +38,17 @@ public PruneSemiJoinColumns() } @Override - protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, SemiJoinNode semiJoinNode, Set referencedOutputs) + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, SemiJoinNode semiJoinNode, Set referencedOutputs) { if (!referencedOutputs.contains(semiJoinNode.getSemiJoinOutput())) { return Optional.of(semiJoinNode.getSource()); } - Set requiredSourceInputs = Streams.concat( + Set requiredSourceInputs = Streams.concat( referencedOutputs.stream() - .filter(symbol -> !symbol.equals(semiJoinNode.getSemiJoinOutput())), - Stream.of(semiJoinNode.getSourceJoinSymbol()), - semiJoinNode.getSourceHashSymbol().map(Stream::of).orElse(Stream.empty())) + .filter(variable -> !variable.equals(semiJoinNode.getSemiJoinOutput())), + Stream.of(semiJoinNode.getSourceJoinVariable()), + semiJoinNode.getSourceHashVariable().map(Stream::of).orElse(Stream.empty())) .collect(toImmutableSet()); return restrictOutputs(idAllocator, semiJoinNode.getSource(), requiredSourceInputs) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinFilteringSourceColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinFilteringSourceColumns.java index a0b9b46b90cb1..77b17654dd22e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinFilteringSourceColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinFilteringSourceColumns.java @@ -15,7 +15,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.SemiJoinNode; import com.google.common.collect.ImmutableList; @@ -42,9 +42,9 @@ public Pattern getPattern() @Override public Result apply(SemiJoinNode semiJoinNode, Captures captures, Context context) { - Set requiredFilteringSourceInputs = Streams.concat( - Stream.of(semiJoinNode.getFilteringSourceJoinSymbol()), - semiJoinNode.getFilteringSourceHashSymbol().map(Stream::of).orElse(Stream.empty())) + Set requiredFilteringSourceInputs = Streams.concat( + Stream.of(semiJoinNode.getFilteringSourceJoinVariable()), + semiJoinNode.getFilteringSourceHashVariable().map(Stream::of).orElse(Stream.empty())) .collect(toImmutableSet()); return restrictOutputs(context.getIdAllocator(), semiJoinNode.getFilteringSource(), requiredFilteringSourceInputs) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java index 71d41e7ccbf33..7c6ef0a6ab060 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java @@ -14,7 +14,8 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.TableScanNode; @@ -34,13 +35,13 @@ public PruneTableScanColumns() } @Override - protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, TableScanNode tableScanNode, Set referencedOutputs) + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, TableScanNode tableScanNode, Set referencedOutputs) { return Optional.of( new TableScanNode( tableScanNode.getId(), tableScanNode.getTable(), - filteredCopy(tableScanNode.getOutputSymbols(), referencedOutputs::contains), + filteredCopy(tableScanNode.getOutputVariables(), referencedOutputs::contains), filterKeys(tableScanNode.getAssignments(), referencedOutputs::contains), tableScanNode.getCurrentConstraint(), tableScanNode.getEnforcedConstraint())); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTopNColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTopNColumns.java index 808226e31bd6a..b6a95d25b75ce 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTopNColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTopNColumns.java @@ -14,7 +14,8 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.TopNNode; import com.google.common.collect.Streams; @@ -35,9 +36,9 @@ public PruneTopNColumns() } @Override - protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, TopNNode topNNode, Set referencedOutputs) + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, TopNNode topNNode, Set referencedOutputs) { - Set prunedTopNInputs = Streams.concat( + Set prunedTopNInputs = Streams.concat( referencedOutputs.stream(), topNNode.getOrderingScheme().getOrderBy().stream()) .collect(toImmutableSet()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java index a3e77e557439d..6489a44c80aa4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java @@ -15,7 +15,8 @@ import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.relation.RowExpression; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.google.common.collect.ImmutableList; @@ -38,14 +39,16 @@ public PruneValuesColumns() } @Override - protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, ValuesNode valuesNode, Set referencedOutputs) + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, ValuesNode valuesNode, Set referencedOutputs) { - List newOutputs = filteredCopy(valuesNode.getOutputSymbols(), referencedOutputs::contains); + List newOutputs = filteredCopy(valuesNode.getOutputVariables(), referencedOutputs::contains); + + List newOutputVariables = filteredCopy(valuesNode.getOutputVariables(), referencedOutputs::contains); // for each output of project, the corresponding column in the values node int[] mapping = new int[newOutputs.size()]; for (int i = 0; i < mapping.length; i++) { - mapping[i] = valuesNode.getOutputSymbols().indexOf(newOutputs.get(i)); + mapping[i] = valuesNode.getOutputVariables().indexOf(newOutputs.get(i)); } ImmutableList.Builder> rowsBuilder = ImmutableList.builder(); @@ -55,6 +58,6 @@ protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, .collect(Collectors.toList())); } - return Optional.of(new ValuesNode(valuesNode.getId(), newOutputs, rowsBuilder.build())); + return Optional.of(new ValuesNode(valuesNode.getId(), newOutputVariables, rowsBuilder.build())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneWindowColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneWindowColumns.java index 592be70547dd1..364ddcc359190 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneWindowColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneWindowColumns.java @@ -14,7 +14,8 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.optimizations.WindowNodeUtil; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.WindowNode; @@ -37,18 +38,16 @@ public PruneWindowColumns() } @Override - protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, WindowNode windowNode, Set referencedOutputs) + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, WindowNode windowNode, Set referencedOutputs) { - Map referencedFunctions = Maps.filterKeys( - windowNode.getWindowFunctions(), - referencedOutputs::contains); + Map referencedFunctions = Maps.filterKeys(windowNode.getWindowFunctions(), referencedOutputs::contains); if (referencedFunctions.isEmpty()) { return Optional.of(windowNode.getSource()); } - ImmutableSet.Builder referencedInputs = ImmutableSet.builder() - .addAll(windowNode.getSource().getOutputSymbols().stream() + ImmutableSet.Builder referencedInputs = ImmutableSet.builder() + .addAll(windowNode.getSource().getOutputVariables().stream() .filter(referencedOutputs::contains) .iterator()) .addAll(windowNode.getPartitionBy()); @@ -57,10 +56,10 @@ protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, orderingScheme -> orderingScheme .getOrderBy() .forEach(referencedInputs::add)); - windowNode.getHashSymbol().ifPresent(referencedInputs::add); + windowNode.getHashVariable().ifPresent(referencedInputs::add); for (WindowNode.Function windowFunction : referencedFunctions.values()) { - referencedInputs.addAll(WindowNodeUtil.extractWindowFunctionUnique(windowFunction)); + referencedInputs.addAll(symbolAllocator.toVariableReferences(WindowNodeUtil.extractWindowFunctionUnique(windowFunction))); windowFunction.getFrame().getStartValue().ifPresent(referencedInputs::add); windowFunction.getFrame().getEndValue().ifPresent(referencedInputs::add); } @@ -71,11 +70,11 @@ protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, .orElse(windowNode.getSource()), windowNode.getSpecification(), referencedFunctions, - windowNode.getHashSymbol(), + windowNode.getHashVariable(), windowNode.getPrePartitionedInputs(), windowNode.getPreSortedOrderPrefix()); - if (prunedWindowNode.getOutputSymbols().size() == windowNode.getOutputSymbols().size()) { + if (prunedWindowNode.getOutputVariables().size() == windowNode.getOutputVariables().size()) { // Neither function pruning nor input pruning was successful. return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java index fdf0b8ccf40b9..e13187154f35f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java @@ -21,8 +21,8 @@ import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.OrderingScheme; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Rule; @@ -47,7 +47,7 @@ import static com.facebook.presto.SystemSessionProperties.shouldPushAggregationThroughJoin; import static com.facebook.presto.matching.Capture.newCapture; -import static com.facebook.presto.sql.planner.ExpressionSymbolInliner.inlineSymbols; +import static com.facebook.presto.sql.planner.ExpressionVariableInliner.inlineVariables; import static com.facebook.presto.sql.planner.optimizations.DistinctOutputQueryUtil.isDistinct; import static com.facebook.presto.sql.planner.plan.AggregationNode.globalAggregation; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; @@ -130,12 +130,12 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont if (join.getFilter().isPresent() || !(join.getType() == JoinNode.Type.LEFT || join.getType() == JoinNode.Type.RIGHT) - || !groupsOnAllColumns(aggregation, getOuterTable(join).getOutputSymbols()) + || !groupsOnAllColumns(aggregation, getOuterTable(join).getOutputVariables()) || !isDistinct(context.getLookup().resolve(getOuterTable(join)), context.getLookup()::resolve)) { return Result.empty(); } - List groupingKeys = join.getCriteria().stream() + List groupingKeys = join.getCriteria().stream() .map(join.getType() == JoinNode.Type.RIGHT ? JoinNode.EquiJoinClause::getLeft : JoinNode.EquiJoinClause::getRight) .collect(toImmutableList()); AggregationNode rewrittenAggregation = new AggregationNode( @@ -145,8 +145,8 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont singleGroupingSet(groupingKeys), ImmutableList.of(), aggregation.getStep(), - aggregation.getHashSymbol(), - aggregation.getGroupIdSymbol()); + aggregation.getHashVariable(), + aggregation.getGroupIdVariable()); JoinNode rewrittenJoin; if (join.getType() == JoinNode.Type.LEFT) { @@ -156,13 +156,13 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont join.getLeft(), rewrittenAggregation, join.getCriteria(), - ImmutableList.builder() - .addAll(join.getLeft().getOutputSymbols()) + ImmutableList.builder() + .addAll(join.getLeft().getOutputVariables()) .addAll(rewrittenAggregation.getAggregations().keySet()) .build(), join.getFilter(), - join.getLeftHashSymbol(), - join.getRightHashSymbol(), + join.getLeftHashVariable(), + join.getRightHashVariable(), join.getDistributionType()); } else { @@ -172,13 +172,13 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont rewrittenAggregation, join.getRight(), join.getCriteria(), - ImmutableList.builder() + ImmutableList.builder() .addAll(rewrittenAggregation.getAggregations().keySet()) - .addAll(join.getRight().getOutputSymbols()) + .addAll(join.getRight().getOutputVariables()) .build(), join.getFilter(), - join.getLeftHashSymbol(), - join.getRightHashSymbol(), + join.getLeftHashVariable(), + join.getRightHashVariable(), join.getDistributionType()); } @@ -216,7 +216,7 @@ private static PlanNode getOuterTable(JoinNode join) return outerNode; } - private static boolean groupsOnAllColumns(AggregationNode node, List columns) + private static boolean groupsOnAllColumns(AggregationNode node, List columns) { return new HashSet<>(node.getGroupingKeys()).equals(new HashSet<>(columns)); } @@ -242,7 +242,7 @@ private Optional coalesceWithNullAggregation(AggregationNode aggregati MappedAggregationInfo aggregationOverNullInfo = aggregationOverNullInfoResultNode.get(); AggregationNode aggregationOverNull = aggregationOverNullInfo.getAggregation(); - Map sourceAggregationToOverNullMapping = aggregationOverNullInfo.getSymbolMapping(); + Map sourceAggregationToOverNullMapping = aggregationOverNullInfo.getVariableMapping(); // Do a cross join with the aggregation over null JoinNode crossJoin = new JoinNode( @@ -251,9 +251,9 @@ private Optional coalesceWithNullAggregation(AggregationNode aggregati outerJoin, aggregationOverNull, ImmutableList.of(), - ImmutableList.builder() - .addAll(outerJoin.getOutputSymbols()) - .addAll(aggregationOverNull.getOutputSymbols()) + ImmutableList.builder() + .addAll(outerJoin.getOutputVariables()) + .addAll(aggregationOverNull.getOutputVariables()) .build(), Optional.empty(), Optional.empty(), @@ -262,12 +262,12 @@ private Optional coalesceWithNullAggregation(AggregationNode aggregati // Add coalesce expressions for all aggregation functions Assignments.Builder assignmentsBuilder = Assignments.builder(); - for (Symbol symbol : outerJoin.getOutputSymbols()) { - if (aggregationNode.getAggregations().containsKey(symbol)) { - assignmentsBuilder.put(symbol, new CoalesceExpression(symbol.toSymbolReference(), sourceAggregationToOverNullMapping.get(symbol).toSymbolReference())); + for (VariableReferenceExpression variable : outerJoin.getOutputVariables()) { + if (aggregationNode.getAggregations().keySet().contains(variable)) { + assignmentsBuilder.put(variable, new CoalesceExpression(new SymbolReference(variable.getName()), new SymbolReference(sourceAggregationToOverNullMapping.get(variable).getName()))); } else { - assignmentsBuilder.put(symbol, symbol.toSymbolReference()); + assignmentsBuilder.put(variable, new SymbolReference(variable.getName())); } } return Optional.of(new ProjectNode(idAllocator.getNextId(), crossJoin, assignmentsBuilder.build())); @@ -279,47 +279,48 @@ private Optional createAggregationOverNull(AggregationNod // Map the output symbols from the referenceAggregation's source // to symbol references for the new values node. NullLiteral nullLiteral = new NullLiteral(); - ImmutableList.Builder nullSymbols = ImmutableList.builder(); + ImmutableList.Builder nullVariables = ImmutableList.builder(); ImmutableList.Builder nullLiterals = ImmutableList.builder(); - ImmutableMap.Builder sourcesSymbolMappingBuilder = ImmutableMap.builder(); - for (Symbol sourceSymbol : referenceAggregation.getSource().getOutputSymbols()) { + ImmutableMap.Builder sourcesVariableMappingBuilder = ImmutableMap.builder(); + for (VariableReferenceExpression sourceVariable : referenceAggregation.getSource().getOutputVariables()) { nullLiterals.add(castToRowExpression(nullLiteral)); - Symbol nullSymbol = symbolAllocator.newSymbol(nullLiteral, symbolAllocator.getTypes().get(sourceSymbol)); - nullSymbols.add(nullSymbol); - sourcesSymbolMappingBuilder.put(sourceSymbol, nullSymbol.toSymbolReference()); + VariableReferenceExpression nullVariable = symbolAllocator.newVariable(nullLiteral, sourceVariable.getType()); + nullVariables.add(nullVariable); + // TODO The type should be from sourceVariable.getType + sourcesVariableMappingBuilder.put(sourceVariable, new SymbolReference(nullVariable.getName())); } ValuesNode nullRow = new ValuesNode( idAllocator.getNextId(), - nullSymbols.build(), + nullVariables.build(), ImmutableList.of(nullLiterals.build())); - Map sourcesSymbolMapping = sourcesSymbolMappingBuilder.build(); + Map sourcesVariableMapping = sourcesVariableMappingBuilder.build(); // For each aggregation function in the reference node, create a corresponding aggregation function // that points to the nullRow. Map the symbols from the aggregations in referenceAggregation to the // symbols in these new aggregations. - ImmutableMap.Builder aggregationsSymbolMappingBuilder = ImmutableMap.builder(); - ImmutableMap.Builder aggregationsOverNullBuilder = ImmutableMap.builder(); - for (Map.Entry entry : referenceAggregation.getAggregations().entrySet()) { - Symbol aggregationSymbol = entry.getKey(); + ImmutableMap.Builder aggregationsVariableMappingBuilder = ImmutableMap.builder(); + ImmutableMap.Builder aggregationsOverNullBuilder = ImmutableMap.builder(); + for (Map.Entry entry : referenceAggregation.getAggregations().entrySet()) { + VariableReferenceExpression aggregationVariable = entry.getKey(); AggregationNode.Aggregation aggregation = entry.getValue(); - if (!isUsingSymbols(aggregation, sourcesSymbolMapping.keySet())) { + if (!isUsingVariables(aggregation, sourcesVariableMapping.keySet())) { return Optional.empty(); } AggregationNode.Aggregation overNullAggregation = new AggregationNode.Aggregation( aggregation.getFunctionHandle(), - aggregation.getArguments().stream().map(argument -> inlineSymbols(sourcesSymbolMapping, argument)).collect(toImmutableList()), - aggregation.getFilter().map(filter -> inlineSymbols(sourcesSymbolMapping, filter)), - aggregation.getOrderBy().map(orderBy -> inlineOrderBySymbols(sourcesSymbolMapping, orderBy)), + aggregation.getArguments().stream().map(argument -> inlineVariables(sourcesVariableMapping, argument, symbolAllocator.getTypes())).collect(toImmutableList()), + aggregation.getFilter().map(filter -> inlineVariables(sourcesVariableMapping, filter, symbolAllocator.getTypes())), + aggregation.getOrderBy().map(orderBy -> inlineOrderByVariables(sourcesVariableMapping, orderBy)), aggregation.isDistinct(), - aggregation.getMask().map(x -> Symbol.from(sourcesSymbolMapping.get(x)))); + aggregation.getMask().map(x -> new VariableReferenceExpression(sourcesVariableMapping.get(x).getName(), x.getType()))); String functionName = functionManager.getFunctionMetadata(overNullAggregation.getFunctionHandle()).getName(); - Symbol overNullSymbol = symbolAllocator.newSymbol(functionName, symbolAllocator.getTypes().get(aggregationSymbol)); - aggregationsOverNullBuilder.put(overNullSymbol, overNullAggregation); - aggregationsSymbolMappingBuilder.put(aggregationSymbol, overNullSymbol); + VariableReferenceExpression overNull = symbolAllocator.newVariable(functionName, aggregationVariable.getType()); + aggregationsOverNullBuilder.put(overNull, overNullAggregation); + aggregationsVariableMappingBuilder.put(aggregationVariable, overNull); } - Map aggregationsSymbolMapping = aggregationsSymbolMappingBuilder.build(); + Map aggregationsSymbolMapping = aggregationsVariableMappingBuilder.build(); // create an aggregation node whose source is the null row. AggregationNode aggregationOverNullRow = new AggregationNode( @@ -335,41 +336,42 @@ private Optional createAggregationOverNull(AggregationNod return Optional.of(new MappedAggregationInfo(aggregationOverNullRow, aggregationsSymbolMapping)); } - private static OrderingScheme inlineOrderBySymbols(Map symbolMapping, OrderingScheme orderingScheme) + private static OrderingScheme inlineOrderByVariables(Map variableMapping, OrderingScheme orderingScheme) { // This is a logic expanded from ExpressionTreeRewriter::rewriteSortItems - ImmutableList.Builder orderBy = ImmutableList.builder(); - ImmutableMap.Builder ordering = new ImmutableMap.Builder<>(); - for (Symbol symbol : orderingScheme.getOrderBy()) { - Symbol translated = Symbol.from(symbolMapping.get(symbol)); + ImmutableList.Builder orderBy = ImmutableList.builder(); + ImmutableMap.Builder ordering = new ImmutableMap.Builder<>(); + for (VariableReferenceExpression variable : orderingScheme.getOrderBy()) { + VariableReferenceExpression translated = new VariableReferenceExpression(variableMapping.get(variable).getName(), variable.getType()); orderBy.add(translated); - ordering.put(translated, orderingScheme.getOrdering(symbol)); + ordering.put(translated, orderingScheme.getOrdering(variable)); } return new OrderingScheme(orderBy.build(), ordering.build()); } - private static boolean isUsingSymbols(AggregationNode.Aggregation aggregation, Set sourceSymbols) + private static boolean isUsingVariables(AggregationNode.Aggregation aggregation, Set sourceVariables) { List functionArguments = aggregation.getArguments(); - return sourceSymbols.stream() - .map(Symbol::toSymbolReference) + return sourceVariables.stream() + .map(VariableReferenceExpression::getName) + .map(SymbolReference::new) .anyMatch(functionArguments::contains); } private static class MappedAggregationInfo { private final AggregationNode aggregationNode; - private final Map symbolMapping; + private final Map variableMapping; - public MappedAggregationInfo(AggregationNode aggregationNode, Map symbolMapping) + public MappedAggregationInfo(AggregationNode aggregationNode, Map variableMapping) { this.aggregationNode = aggregationNode; - this.symbolMapping = symbolMapping; + this.variableMapping = variableMapping; } - public Map getSymbolMapping() + public Map getVariableMapping() { - return symbolMapping; + return variableMapping; } public AggregationNode getAggregation() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java index a54853cdc43c9..45d527a3fe987 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java @@ -19,9 +19,9 @@ import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.SymbolMapper; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -31,6 +31,7 @@ import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.LambdaExpression; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import java.util.ArrayList; @@ -113,12 +114,12 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context if (exchangeNode.getType() == REPARTITION) { // if partitioning columns are not a subset of grouping keys, // we can't push this through - List partitioningColumns = exchangeNode.getPartitioningScheme() + List partitioningColumns = exchangeNode.getPartitioningScheme() .getPartitioning() .getArguments() .stream() .filter(Partitioning.ArgumentBinding::isVariable) - .map(Partitioning.ArgumentBinding::getColumn) + .map(Partitioning.ArgumentBinding::getVariableReference) .collect(Collectors.toList()); if (!aggregationNode.getGroupingKeys().containsAll(partitioningColumns)) { @@ -127,7 +128,7 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context } // currently, we only support plans that don't use pre-computed hash functions - if (aggregationNode.getHashSymbol().isPresent() || exchangeNode.getPartitioningScheme().getHashColumn().isPresent()) { + if (aggregationNode.getHashVariable().isPresent() || exchangeNode.getPartitioningScheme().getHashColumn().isPresent()) { return Result.empty(); } @@ -150,9 +151,9 @@ private PlanNode pushPartial(AggregationNode aggregation, ExchangeNode exchange, PlanNode source = exchange.getSources().get(i); SymbolMapper.Builder mappingsBuilder = SymbolMapper.builder(); - for (int outputIndex = 0; outputIndex < exchange.getOutputSymbols().size(); outputIndex++) { - Symbol output = exchange.getOutputSymbols().get(outputIndex); - Symbol input = exchange.getInputs().get(i).get(outputIndex); + for (int outputIndex = 0; outputIndex < exchange.getOutputVariables().size(); outputIndex++) { + VariableReferenceExpression output = exchange.getOutputVariables().get(outputIndex); + VariableReferenceExpression input = exchange.getInputs().get(i).get(outputIndex); if (!output.equals(input)) { mappingsBuilder.put(output, input); } @@ -163,22 +164,22 @@ private PlanNode pushPartial(AggregationNode aggregation, ExchangeNode exchange, Assignments.Builder assignments = Assignments.builder(); - for (Symbol output : aggregation.getOutputSymbols()) { - Symbol input = symbolMapper.map(output); - assignments.put(output, input.toSymbolReference()); + for (VariableReferenceExpression output : aggregation.getOutputVariables()) { + VariableReferenceExpression input = symbolMapper.map(output); + assignments.put(output, new SymbolReference(input.getName())); } partials.add(new ProjectNode(context.getIdAllocator().getNextId(), mappedPartial, assignments.build())); } for (PlanNode node : partials) { - verify(aggregation.getOutputSymbols().equals(node.getOutputSymbols())); + verify(aggregation.getOutputVariables().equals(node.getOutputVariables())); } - // Since this exchange source is now guaranteed to have the same symbols as the inputs to the the partial // aggregation, we don't need to rewrite symbols in the partitioning function + List aggregationOutputs = aggregation.getOutputVariables(); PartitioningScheme partitioning = new PartitioningScheme( exchange.getPartitioningScheme().getPartitioning(), - aggregation.getOutputSymbols(), + aggregationOutputs, exchange.getPartitioningScheme().getHashColumn(), exchange.getPartitioningScheme().isReplicateNullsAndAny(), exchange.getPartitioningScheme().getBucketToPartition()); @@ -189,24 +190,24 @@ private PlanNode pushPartial(AggregationNode aggregation, ExchangeNode exchange, exchange.getScope(), partitioning, partials, - ImmutableList.copyOf(Collections.nCopies(partials.size(), aggregation.getOutputSymbols())), + ImmutableList.copyOf(Collections.nCopies(partials.size(), aggregationOutputs)), Optional.empty()); } private PlanNode split(AggregationNode node, Context context) { // otherwise, add a partial and final with an exchange in between - Map intermediateAggregation = new HashMap<>(); - Map finalAggregation = new HashMap<>(); - for (Map.Entry entry : node.getAggregations().entrySet()) { + Map intermediateAggregation = new HashMap<>(); + Map finalAggregation = new HashMap<>(); + for (Map.Entry entry : node.getAggregations().entrySet()) { AggregationNode.Aggregation originalAggregation = entry.getValue(); String functionName = functionManager.getFunctionMetadata(originalAggregation.getFunctionHandle()).getName(); FunctionHandle functionHandle = originalAggregation.getFunctionHandle(); InternalAggregationFunction function = functionManager.getAggregateFunctionImplementation(functionHandle); - Symbol intermediateSymbol = context.getSymbolAllocator().newSymbol(functionName, function.getIntermediateType()); + VariableReferenceExpression intermediateVariable = context.getSymbolAllocator().newVariable(functionName, function.getIntermediateType()); checkState(!originalAggregation.getOrderBy().isPresent(), "Aggregate with ORDER BY does not support partial aggregation"); - intermediateAggregation.put(intermediateSymbol, new AggregationNode.Aggregation( + intermediateAggregation.put(intermediateVariable, new AggregationNode.Aggregation( functionHandle, originalAggregation.getArguments(), originalAggregation.getFilter(), @@ -217,10 +218,9 @@ private PlanNode split(AggregationNode node, Context context) // rewrite final aggregation in terms of intermediate function finalAggregation.put(entry.getKey(), new AggregationNode.Aggregation( - functionHandle, ImmutableList.builder() - .add(intermediateSymbol.toSymbolReference()) + .add(new SymbolReference(intermediateVariable.getName())) .addAll(originalAggregation.getArguments().stream() .filter(LambdaExpression.class::isInstance) .collect(toImmutableList())) @@ -240,8 +240,8 @@ private PlanNode split(AggregationNode node, Context context) // through the exchange may or may not preserve these properties. Hence, it is safest to drop preGroupedSymbols here. ImmutableList.of(), PARTIAL, - node.getHashSymbol(), - node.getGroupIdSymbol()); + node.getHashVariable(), + node.getGroupIdVariable()); return new AggregationNode( node.getId(), @@ -252,7 +252,7 @@ private PlanNode split(AggregationNode node, Context context) // through the exchange may or may not preserve these properties. Hence, it is safest to drop preGroupedSymbols here. ImmutableList.of(), FINAL, - node.getHashSymbol(), - node.getGroupIdSymbol()); + node.getHashVariable(), + node.getGroupIdVariable()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java index 48247432ddc67..c27bb5b2ecff0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java @@ -17,8 +17,10 @@ import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -61,7 +63,7 @@ private static boolean isSupportedAggregationNode(AggregationNode aggregationNod return false; } - if (aggregationNode.getHashSymbol().isPresent()) { + if (aggregationNode.getHashVariable().isPresent()) { // TODO: add support for hash symbol in aggregation node return false; } @@ -90,66 +92,67 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context } // TODO: leave partial aggregation above Join? - if (allAggregationsOn(aggregationNode.getAggregations(), joinNode.getLeft().getOutputSymbols())) { + if (allAggregationsOn(aggregationNode.getAggregations(), joinNode.getLeft().getOutputVariables())) { return Result.ofPlanNode(pushPartialToLeftChild(aggregationNode, joinNode, context)); } - else if (allAggregationsOn(aggregationNode.getAggregations(), joinNode.getRight().getOutputSymbols())) { + else if (allAggregationsOn(aggregationNode.getAggregations(), joinNode.getRight().getOutputVariables())) { return Result.ofPlanNode(pushPartialToRightChild(aggregationNode, joinNode, context)); } return Result.empty(); } - private boolean allAggregationsOn(Map aggregations, List symbols) + private boolean allAggregationsOn(Map aggregations, List variables) { - Set inputs = aggregations.values() + Set inputNames = aggregations.values() .stream() .map(AggregationNodeUtils::extractUnique) .flatMap(Set::stream) + .map(Symbol::getName) .collect(toImmutableSet()); - return symbols.containsAll(inputs); + return variables.stream().map(VariableReferenceExpression::getName).collect(toImmutableSet()).containsAll(inputNames); } private PlanNode pushPartialToLeftChild(AggregationNode node, JoinNode child, Context context) { - Set joinLeftChildSymbols = ImmutableSet.copyOf(child.getLeft().getOutputSymbols()); - List groupingSet = getPushedDownGroupingSet(node, joinLeftChildSymbols, intersection(getJoinRequiredSymbols(child), joinLeftChildSymbols)); + Set joinLeftChildVariables = ImmutableSet.copyOf(child.getLeft().getOutputVariables()); + List groupingSet = getPushedDownGroupingSet(node, joinLeftChildVariables, intersection(getJoinRequiredVariables(child, context.getSymbolAllocator().getTypes()), joinLeftChildVariables)); AggregationNode pushedAggregation = replaceAggregationSource(node, child.getLeft(), groupingSet); return pushPartialToJoin(node, child, pushedAggregation, child.getRight(), context); } private PlanNode pushPartialToRightChild(AggregationNode node, JoinNode child, Context context) { - Set joinRightChildSymbols = ImmutableSet.copyOf(child.getRight().getOutputSymbols()); - List groupingSet = getPushedDownGroupingSet(node, joinRightChildSymbols, intersection(getJoinRequiredSymbols(child), joinRightChildSymbols)); + Set joinRightChildVariables = ImmutableSet.copyOf(child.getRight().getOutputVariables()); + List groupingSet = getPushedDownGroupingSet(node, joinRightChildVariables, intersection(getJoinRequiredVariables(child, context.getSymbolAllocator().getTypes()), joinRightChildVariables)); AggregationNode pushedAggregation = replaceAggregationSource(node, child.getRight(), groupingSet); return pushPartialToJoin(node, child, child.getLeft(), pushedAggregation, context); } - private Set getJoinRequiredSymbols(JoinNode node) + private Set getJoinRequiredVariables(JoinNode node, TypeProvider types) { return Streams.concat( node.getCriteria().stream().map(JoinNode.EquiJoinClause::getLeft), node.getCriteria().stream().map(JoinNode.EquiJoinClause::getRight), - node.getFilter().map(OriginalExpressionUtils::castToExpression).map(SymbolsExtractor::extractUnique).orElse(ImmutableSet.of()).stream(), - node.getLeftHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream(), - node.getRightHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream()) + node.getFilter().map(OriginalExpressionUtils::castToExpression).map(expression -> SymbolsExtractor.extractUniqueVariable(expression, types)).orElse(ImmutableSet.of()).stream(), + node.getLeftHashVariable().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream(), + node.getRightHashVariable().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream()) .collect(toImmutableSet()); } - private List getPushedDownGroupingSet(AggregationNode aggregation, Set availableSymbols, Set requiredJoinSymbols) + private List getPushedDownGroupingSet(AggregationNode aggregation, Set availableVariables, Set requiredJoinVariables) { - List groupingSet = aggregation.getGroupingKeys(); + List groupingSet = aggregation.getGroupingKeys(); - // keep symbols that are directly from the join's child (availableSymbols) - List pushedDownGroupingSet = groupingSet.stream() - .filter(availableSymbols::contains) + // keep variables that are directly from the join's child (availableVariables) + List pushedDownGroupingSet = groupingSet.stream() + .filter(availableVariables::contains) .collect(Collectors.toList()); - // add missing required join symbols to grouping set - Set existingSymbols = new HashSet<>(pushedDownGroupingSet); - requiredJoinSymbols.stream() - .filter(existingSymbols::add) + // add missing required join variables to grouping set + Set existingVariables = new HashSet<>(pushedDownGroupingSet); + requiredJoinVariables.stream() + .filter(existingVariables::add) .forEach(pushedDownGroupingSet::add); return pushedDownGroupingSet; @@ -158,7 +161,7 @@ private List getPushedDownGroupingSet(AggregationNode aggregation, Set groupingKeys) + List groupingKeys) { return new AggregationNode( aggregation.getId(), @@ -167,8 +170,8 @@ private AggregationNode replaceAggregationSource( singleGroupingSet(groupingKeys), ImmutableList.of(), aggregation.getStep(), - aggregation.getHashSymbol(), - aggregation.getGroupIdSymbol()); + aggregation.getHashVariable(), + aggregation.getGroupIdVariable()); } private PlanNode pushPartialToJoin( @@ -184,14 +187,14 @@ private PlanNode pushPartialToJoin( leftChild, rightChild, child.getCriteria(), - ImmutableList.builder() - .addAll(leftChild.getOutputSymbols()) - .addAll(rightChild.getOutputSymbols()) + ImmutableList.builder() + .addAll(leftChild.getOutputVariables()) + .addAll(rightChild.getOutputVariables()) .build(), child.getFilter(), - child.getLeftHashSymbol(), - child.getRightHashSymbol(), + child.getLeftHashVariable(), + child.getRightHashVariable(), child.getDistributionType()); - return restrictOutputs(context.getIdAllocator(), joinNode, ImmutableSet.copyOf(aggregation.getOutputSymbols())).orElse(joinNode); + return restrictOutputs(context.getIdAllocator(), joinNode, ImmutableSet.copyOf(aggregation.getOutputVariables())).orElse(joinNode); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java index 600ac6ab241da..d413cf90cfa23 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java @@ -16,9 +16,11 @@ import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -35,8 +37,9 @@ import java.util.Set; import static com.facebook.presto.matching.Capture.newCapture; -import static com.facebook.presto.sql.planner.ExpressionSymbolInliner.inlineSymbols; +import static com.facebook.presto.sql.planner.ExpressionVariableInliner.inlineVariables; import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; +import static com.facebook.presto.sql.planner.optimizations.AddExchanges.toVariableReference; import static com.facebook.presto.sql.planner.plan.Patterns.exchange; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; @@ -82,65 +85,67 @@ public Pattern getPattern() public Result apply(ProjectNode project, Captures captures, Context context) { ExchangeNode exchange = captures.get(CHILD); - Set partitioningColumns = exchange.getPartitioningScheme().getPartitioning().getColumns(); + Set partitioningColumns = exchange.getPartitioningScheme().getPartitioning().getVariableReferences(); ImmutableList.Builder newSourceBuilder = ImmutableList.builder(); - ImmutableList.Builder> inputsBuilder = ImmutableList.builder(); + ImmutableList.Builder> inputsBuilder = ImmutableList.builder(); + TypeProvider types = context.getSymbolAllocator().getTypes(); for (int i = 0; i < exchange.getSources().size(); i++) { - Map outputToInputMap = extractExchangeOutputToInput(exchange, i); + Map outputToInputMap = extractExchangeOutputToInput(exchange, i); Assignments.Builder projections = Assignments.builder(); - ImmutableList.Builder inputs = ImmutableList.builder(); + ImmutableList.Builder inputs = ImmutableList.builder(); // Need to retain the partition keys for the exchange partitioningColumns.stream() .map(outputToInputMap::get) .forEach(nameReference -> { - Symbol symbol = Symbol.from(nameReference); - projections.put(symbol, nameReference); - inputs.add(symbol); + VariableReferenceExpression variable = toVariableReference(Symbol.from(nameReference), types); + projections.put(variable, nameReference); + inputs.add(variable); }); if (exchange.getPartitioningScheme().getHashColumn().isPresent()) { // Need to retain the hash symbol for the exchange - projections.put(exchange.getPartitioningScheme().getHashColumn().get(), exchange.getPartitioningScheme().getHashColumn().get().toSymbolReference()); - inputs.add(exchange.getPartitioningScheme().getHashColumn().get()); + VariableReferenceExpression hashVariable = exchange.getPartitioningScheme().getHashColumn().get(); + projections.put(hashVariable, new SymbolReference(hashVariable.getName())); + inputs.add(hashVariable); } if (exchange.getOrderingScheme().isPresent()) { // need to retain ordering columns for the exchange exchange.getOrderingScheme().get().getOrderBy().stream() // do not project the same symbol twice as ExchangeNode verifies that source input symbols match partitioning scheme outputLayout - .filter(symbol -> !partitioningColumns.contains(symbol)) + .filter(variable -> !partitioningColumns.contains(variable)) .map(outputToInputMap::get) .forEach(nameReference -> { - Symbol symbol = Symbol.from(nameReference); - projections.put(symbol, nameReference); - inputs.add(symbol); + VariableReferenceExpression variable = toVariableReference(Symbol.from(nameReference), types); + projections.put(variable, nameReference); + inputs.add(variable); }); } - for (Map.Entry projection : project.getAssignments().entrySet()) { - Expression translatedExpression = inlineSymbols(outputToInputMap, projection.getValue()); - Type type = context.getSymbolAllocator().getTypes().get(projection.getKey()); - Symbol symbol = context.getSymbolAllocator().newSymbol(translatedExpression, type); - projections.put(symbol, translatedExpression); - inputs.add(symbol); + for (Map.Entry projection : project.getAssignments().entrySet()) { + Expression translatedExpression = inlineVariables(outputToInputMap, projection.getValue(), types); + Type type = projection.getKey().getType(); + VariableReferenceExpression variable = context.getSymbolAllocator().newVariable(translatedExpression, type); + projections.put(variable, translatedExpression); + inputs.add(variable); } newSourceBuilder.add(new ProjectNode(context.getIdAllocator().getNextId(), exchange.getSources().get(i), projections.build())); inputsBuilder.add(inputs.build()); } // Construct the output symbols in the same order as the sources - ImmutableList.Builder outputBuilder = ImmutableList.builder(); + ImmutableList.Builder outputBuilder = ImmutableList.builder(); partitioningColumns.forEach(outputBuilder::add); exchange.getPartitioningScheme().getHashColumn().ifPresent(outputBuilder::add); if (exchange.getOrderingScheme().isPresent()) { exchange.getOrderingScheme().get().getOrderBy().stream() - .filter(symbol -> !partitioningColumns.contains(symbol)) + .filter(variable -> !partitioningColumns.contains(variable)) .forEach(outputBuilder::add); } - for (Map.Entry projection : project.getAssignments().entrySet()) { + for (Map.Entry projection : project.getAssignments().entrySet()) { outputBuilder.add(projection.getKey()); } @@ -162,7 +167,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) exchange.getOrderingScheme()); // we need to strip unnecessary symbols (hash, partitioning columns). - return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(project.getOutputSymbols())).orElse(result)); + return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(project.getOutputVariables())).orElse(result)); } private static boolean isSymbolToSymbolProjection(ProjectNode project) @@ -170,11 +175,11 @@ private static boolean isSymbolToSymbolProjection(ProjectNode project) return project.getAssignments().getExpressions().stream().allMatch(e -> e instanceof SymbolReference); } - private static Map extractExchangeOutputToInput(ExchangeNode exchange, int sourceIndex) + private static Map extractExchangeOutputToInput(ExchangeNode exchange, int sourceIndex) { - Map outputToInputMap = new HashMap<>(); - for (int i = 0; i < exchange.getOutputSymbols().size(); i++) { - outputToInputMap.put(exchange.getOutputSymbols().get(i), exchange.getInputs().get(sourceIndex).get(i).toSymbolReference()); + Map outputToInputMap = new HashMap<>(); + for (int i = 0; i < exchange.getOutputVariables().size(); i++) { + outputToInputMap.put(exchange.getOutputVariables().get(i), new SymbolReference(exchange.getInputs().get(sourceIndex).get(i).getName())); } return outputToInputMap; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java index f58dcd21f1aa4..1ff8587c747bc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java @@ -16,6 +16,7 @@ import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; @@ -34,7 +35,7 @@ import java.util.Map; import static com.facebook.presto.matching.Capture.newCapture; -import static com.facebook.presto.sql.planner.ExpressionSymbolInliner.inlineSymbols; +import static com.facebook.presto.sql.planner.ExpressionVariableInliner.inlineVariables; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; import static com.facebook.presto.sql.planner.plan.Patterns.union; @@ -59,33 +60,38 @@ public Result apply(ProjectNode parent, Captures captures, Context context) UnionNode source = captures.get(CHILD); // OutputLayout of the resultant Union, will be same as the layout of the Project - List outputLayout = parent.getOutputSymbols(); + List outputLayout = parent.getOutputVariables(); // Mapping from the output symbol to ordered list of symbols from each of the sources - ImmutableListMultimap.Builder mappings = ImmutableListMultimap.builder(); + ImmutableListMultimap.Builder mappings = ImmutableListMultimap.builder(); // sources for the resultant UnionNode ImmutableList.Builder outputSources = ImmutableList.builder(); for (int i = 0; i < source.getSources().size(); i++) { - Map outputToInput = Maps.transformValues(source.sourceSymbolMap(i), Symbol::toSymbolReference); // Map: output of union -> input of this source to the union + Map outputToInput = Maps.transformValues(source.sourceVariableMap(i), variable -> new SymbolReference(variable.getName())); // Map: output of union -> input of this source to the union Assignments.Builder assignments = Assignments.builder(); // assignments for the new ProjectNode // mapping from current ProjectNode to new ProjectNode, used to identify the output layout - Map projectSymbolMapping = new HashMap<>(); + Map projectVariableMapping = new HashMap<>(); // Translate the assignments in the ProjectNode using symbols of the source of the UnionNode - for (Map.Entry entry : parent.getAssignments().entrySet()) { - Expression translatedExpression = inlineSymbols(outputToInput, entry.getValue()); - Type type = context.getSymbolAllocator().getTypes().get(entry.getKey()); - Symbol symbol = context.getSymbolAllocator().newSymbol(translatedExpression, type); - assignments.put(symbol, translatedExpression); - projectSymbolMapping.put(entry.getKey(), symbol); + for (Map.Entry entry : parent.getAssignments().entrySet()) { + Expression translatedExpression = inlineVariables(outputToInput, entry.getValue(), context.getSymbolAllocator().getTypes()); + Type type = entry.getKey().getType(); + VariableReferenceExpression variable = context.getSymbolAllocator().newVariable(translatedExpression, type); + assignments.put(variable, translatedExpression); + projectVariableMapping.put(new VariableReferenceExpression(entry.getKey().getName(), type), variable); } outputSources.add(new ProjectNode(context.getIdAllocator().getNextId(), source.getSources().get(i), assignments.build())); - outputLayout.forEach(symbol -> mappings.put(symbol, projectSymbolMapping.get(symbol))); + outputLayout.forEach(variable -> mappings.put(variable, projectVariableMapping.get(variable))); } - return Result.ofPlanNode(new UnionNode(parent.getId(), outputSources.build(), mappings.build(), ImmutableList.copyOf(mappings.build().keySet()))); + return Result.ofPlanNode(new UnionNode(parent.getId(), outputSources.build(), mappings.build())); + } + + private static VariableReferenceExpression getWithMatchingSymbol(Map variableMapping, Symbol symbol) + { + return variableMapping.entrySet().stream().filter(entry -> entry.getKey().getName().equals(symbol.getName())).findAny().map(Map.Entry::getValue).get(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushRemoteExchangeThroughAssignUniqueId.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushRemoteExchangeThroughAssignUniqueId.java index f0f27a70572f1..463ee1c4dae05 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushRemoteExchangeThroughAssignUniqueId.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushRemoteExchangeThroughAssignUniqueId.java @@ -16,8 +16,8 @@ import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.PartitioningScheme; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -62,7 +62,7 @@ public Result apply(ExchangeNode node, Captures captures, Context context) AssignUniqueId assignUniqueId = captures.get(ASSIGN_UNIQUE_ID); PartitioningScheme partitioningScheme = node.getPartitioningScheme(); - if (partitioningScheme.getPartitioning().getColumns().contains(assignUniqueId.getIdColumn())) { + if (partitioningScheme.getPartitioning().getVariableReferences().contains(assignUniqueId.getIdVariable())) { // The column produced by the AssignUniqueId is used in the partitioning scheme of the exchange. // Hence, AssignUniqueId node has to stay below the exchange node. return Result.empty(); @@ -76,20 +76,20 @@ public Result apply(ExchangeNode node, Captures captures, Context context) node.getScope(), new PartitioningScheme( partitioningScheme.getPartitioning(), - removeSymbol(partitioningScheme.getOutputLayout(), assignUniqueId.getIdColumn()), + removeVariable(partitioningScheme.getOutputLayout(), assignUniqueId.getIdVariable()), partitioningScheme.getHashColumn(), partitioningScheme.isReplicateNullsAndAny(), partitioningScheme.getBucketToPartition()), ImmutableList.of(assignUniqueId.getSource()), - ImmutableList.of(removeSymbol(getOnlyElement(node.getInputs()), assignUniqueId.getIdColumn())), + ImmutableList.of(removeVariable(getOnlyElement(node.getInputs()), assignUniqueId.getIdVariable())), Optional.empty()), - assignUniqueId.getIdColumn())); + assignUniqueId.getIdVariable())); } - private static List removeSymbol(List symbols, Symbol symbolToRemove) + private static List removeVariable(List variables, VariableReferenceExpression variableToRemove) { - return symbols.stream() - .filter(symbol -> !symbolToRemove.equals(symbol)) + return variables.stream() + .filter(variable -> !variableToRemove.equals(variable)) .collect(toImmutableList()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTableWriteThroughUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTableWriteThroughUnion.java index c10a7cf17b03a..8fbc5a54dd458 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTableWriteThroughUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTableWriteThroughUnion.java @@ -17,7 +17,7 @@ import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.SymbolMapper; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -69,41 +69,40 @@ public Result apply(TableWriterNode writerNode, Captures captures, Context conte { UnionNode unionNode = captures.get(CHILD); ImmutableList.Builder rewrittenSources = ImmutableList.builder(); - List> sourceMappings = new ArrayList<>(); + List> sourceMappings = new ArrayList<>(); for (int source = 0; source < unionNode.getSources().size(); source++) { rewrittenSources.add(rewriteSource(writerNode, unionNode, source, sourceMappings, context)); } - ImmutableListMultimap.Builder unionMappings = ImmutableListMultimap.builder(); + ImmutableListMultimap.Builder unionMappings = ImmutableListMultimap.builder(); sourceMappings.forEach(mappings -> mappings.forEach(unionMappings::put)); return Result.ofPlanNode( new UnionNode( context.getIdAllocator().getNextId(), rewrittenSources.build(), - unionMappings.build(), - ImmutableList.copyOf(unionMappings.build().keySet()))); + unionMappings.build())); } private static TableWriterNode rewriteSource( TableWriterNode writerNode, UnionNode unionNode, int source, - List> sourceMappings, + List> sourceMappings, Context context) { - Map inputMappings = getInputSymbolMapping(unionNode, source); - ImmutableMap.Builder mappings = ImmutableMap.builder(); + Map inputMappings = getInputVariableMapping(unionNode, source); + ImmutableMap.Builder mappings = ImmutableMap.builder(); mappings.putAll(inputMappings); - ImmutableMap.Builder outputMappings = ImmutableMap.builder(); - for (Symbol outputSymbol : writerNode.getOutputSymbols()) { - if (inputMappings.containsKey(outputSymbol)) { - outputMappings.put(outputSymbol, inputMappings.get(outputSymbol)); + ImmutableMap.Builder outputMappings = ImmutableMap.builder(); + for (VariableReferenceExpression outputVariable : writerNode.getOutputVariables()) { + if (inputMappings.containsKey(outputVariable)) { + outputMappings.put(outputVariable, inputMappings.get(outputVariable)); } else { - Symbol newSymbol = context.getSymbolAllocator().newSymbol(outputSymbol); - outputMappings.put(outputSymbol, newSymbol); - mappings.put(outputSymbol, newSymbol); + VariableReferenceExpression newVariable = context.getSymbolAllocator().newVariable(outputVariable); + outputMappings.put(outputVariable, newVariable); + mappings.put(outputVariable, newVariable); } } sourceMappings.add(outputMappings.build()); @@ -111,11 +110,11 @@ private static TableWriterNode rewriteSource( return symbolMapper.map(writerNode, unionNode.getSources().get(source), context.getIdAllocator().getNextId()); } - private static Map getInputSymbolMapping(UnionNode node, int source) + private static Map getInputVariableMapping(UnionNode node, int source) { - return node.getSymbolMapping() + return node.getVariableMapping() .keySet() .stream() - .collect(toImmutableMap(key -> key, key -> node.getSymbolMapping().get(key).get(source))); + .collect(toImmutableMap(key -> key, key -> node.getVariableMapping().get(key).get(source))); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTopNThroughUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTopNThroughUnion.java index 64682cb9e6402..cf89d9165c588 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTopNThroughUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTopNThroughUnion.java @@ -16,7 +16,7 @@ import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.SymbolMapper; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -60,11 +60,12 @@ public Result apply(TopNNode topNNode, Captures captures, Context context) for (PlanNode source : unionNode.getSources()) { SymbolMapper.Builder symbolMapper = SymbolMapper.builder(); - Set sourceOutputSymbols = ImmutableSet.copyOf(source.getOutputSymbols()); - for (Symbol unionOutput : unionNode.getOutputSymbols()) { - Set inputSymbols = ImmutableSet.copyOf(unionNode.getSymbolMapping().get(unionOutput)); - Symbol unionInput = getLast(intersection(inputSymbols, sourceOutputSymbols)); + Set sourceOutputVariables = ImmutableSet.copyOf(source.getOutputVariables()); + + for (VariableReferenceExpression unionOutput : unionNode.getOutputVariables()) { + Set inputVariables = ImmutableSet.copyOf(unionNode.getVariableMapping().get(unionOutput)); + VariableReferenceExpression unionInput = getLast(intersection(inputVariables, sourceOutputVariables)); symbolMapper.put(unionOutput, unionInput); } sources.add(symbolMapper.build().map(topNNode, source, context.getIdAllocator().getNextId())); @@ -73,7 +74,6 @@ public Result apply(TopNNode topNNode, Captures captures, Context context) return Result.ofPlanNode(new UnionNode( unionNode.getId(), sources.build(), - unionNode.getSymbolMapping(), - unionNode.getOutputSymbols())); + unionNode.getVariableMapping())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java index c2d137d8fc38f..a15f6134182c4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java @@ -74,7 +74,7 @@ public Result apply(TableFinishNode node, Captures captures, Context context) return Result.ofPlanNode( new ValuesNode( node.getId(), - node.getOutputSymbols(), + node.getOutputVariables(), ImmutableList.of(ImmutableList.of(constant(0L, BIGINT))))); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java index 3f67870605a53..c51781b9b6cbd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java @@ -36,7 +36,7 @@ public class RemoveRedundantIdentityProjections private static boolean outputsSameAsSource(ProjectNode node) { - return ImmutableSet.copyOf(node.getOutputSymbols()).equals(ImmutableSet.copyOf(node.getSource().getOutputSymbols())); + return ImmutableSet.copyOf(node.getOutputVariables()).equals(ImmutableSet.copyOf(node.getSource().getOutputVariables())); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveTrivialFilters.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveTrivialFilters.java index da75d4335ec52..fb26966b3eda3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveTrivialFilters.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveTrivialFilters.java @@ -47,7 +47,7 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) } if (predicate.equals(FALSE_LITERAL)) { - return Result.ofPlanNode(new ValuesNode(context.getIdAllocator().getNextId(), filterNode.getOutputSymbols(), ImmutableList.of())); + return Result.ofPlanNode(new ValuesNode(context.getIdAllocator().getNextId(), filterNode.getOutputVariables(), ImmutableList.of())); } return Result.empty(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java index 9ad92f3f41184..bcb3a9fe07468 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java @@ -53,6 +53,6 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context private boolean isUnreferencedScalar(PlanNode planNode, Lookup lookup) { - return planNode.getOutputSymbols().isEmpty() && isScalar(planNode, lookup); + return planNode.getOutputVariables().isEmpty() && isScalar(planNode, lookup); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java index 3cbf9f88472a3..87e01710cbc80 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java @@ -21,9 +21,11 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.planner.EqualityInference; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Rule; @@ -129,7 +131,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) costComparator, multiJoinNode.getFilter(), context); - JoinEnumerationResult result = joinEnumerator.chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols()); + JoinEnumerationResult result = joinEnumerator.chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputVariables()); if (!result.getPlanNode().isPresent()) { return Result.empty(); } @@ -164,7 +166,7 @@ static class JoinEnumerator this.lookup = requireNonNull(context.getLookup(), "lookup is null"); } - private JoinEnumerationResult chooseJoinOrder(LinkedHashSet sources, List outputSymbols) + private JoinEnumerationResult chooseJoinOrder(LinkedHashSet sources, List outputVariables) { context.checkTimeoutNotExhausted(); @@ -175,7 +177,7 @@ private JoinEnumerationResult chooseJoinOrder(LinkedHashSet sources, L ImmutableList.Builder resultBuilder = ImmutableList.builder(); Set> partitions = generatePartitions(sources.size()); for (Set partition : partitions) { - JoinEnumerationResult result = createJoinAccordingToPartitioning(sources, outputSymbols, partition); + JoinEnumerationResult result = createJoinAccordingToPartitioning(sources, outputVariables, partition); if (result.equals(UNKNOWN_COST_RESULT)) { memo.put(multiJoinKey, result); return result; @@ -222,7 +224,7 @@ static Set> generatePartitions(int totalNodes) } @VisibleForTesting - JoinEnumerationResult createJoinAccordingToPartitioning(LinkedHashSet sources, List outputSymbols, Set partitioning) + JoinEnumerationResult createJoinAccordingToPartitioning(LinkedHashSet sources, List outputVariables, Set partitioning) { List sourceList = ImmutableList.copyOf(sources); LinkedHashSet leftSources = partitioning.stream() @@ -231,22 +233,22 @@ JoinEnumerationResult createJoinAccordingToPartitioning(LinkedHashSet LinkedHashSet rightSources = sources.stream() .filter(source -> !leftSources.contains(source)) .collect(toCollection(LinkedHashSet::new)); - return createJoin(leftSources, rightSources, outputSymbols); + return createJoin(leftSources, rightSources, outputVariables); } - private JoinEnumerationResult createJoin(LinkedHashSet leftSources, LinkedHashSet rightSources, List outputSymbols) + private JoinEnumerationResult createJoin(LinkedHashSet leftSources, LinkedHashSet rightSources, List outputVariables) { - Set leftSymbols = leftSources.stream() - .flatMap(node -> node.getOutputSymbols().stream()) + Set leftVariables = leftSources.stream() + .flatMap(node -> node.getOutputVariables().stream()) .collect(toImmutableSet()); - Set rightSymbols = rightSources.stream() - .flatMap(node -> node.getOutputSymbols().stream()) + Set rightVariables = rightSources.stream() + .flatMap(node -> node.getOutputVariables().stream()) .collect(toImmutableSet()); - List joinPredicates = getJoinPredicates(leftSymbols, rightSymbols); + List joinPredicates = getJoinPredicates(leftVariables, rightVariables); List joinConditions = joinPredicates.stream() .filter(JoinEnumerator::isJoinEqualityCondition) - .map(predicate -> toEquiJoinClause((ComparisonExpression) predicate, leftSymbols)) + .map(predicate -> toEquiJoinClause((ComparisonExpression) predicate, leftVariables, context.getSymbolAllocator())) .collect(toImmutableList()); if (joinConditions.isEmpty()) { return INFINITE_COST_RESULT; @@ -255,15 +257,15 @@ private JoinEnumerationResult createJoin(LinkedHashSet leftSources, Li .filter(predicate -> !isJoinEqualityCondition(predicate)) .collect(toImmutableList()); - Set requiredJoinSymbols = ImmutableSet.builder() - .addAll(outputSymbols) - .addAll(SymbolsExtractor.extractUnique(joinPredicates)) + Set requiredJoinVariables = ImmutableSet.builder() + .addAll(outputVariables) + .addAll(SymbolsExtractor.extractUniqueVariable(joinPredicates, context.getSymbolAllocator().getTypes())) .build(); JoinEnumerationResult leftResult = getJoinSource( leftSources, - requiredJoinSymbols.stream() - .filter(leftSymbols::contains) + requiredJoinVariables.stream() + .filter(leftVariables::contains) .collect(toImmutableList())); if (leftResult.equals(UNKNOWN_COST_RESULT)) { return UNKNOWN_COST_RESULT; @@ -276,8 +278,8 @@ private JoinEnumerationResult createJoin(LinkedHashSet leftSources, Li JoinEnumerationResult rightResult = getJoinSource( rightSources, - requiredJoinSymbols.stream() - .filter(rightSymbols::contains) + requiredJoinVariables.stream() + .filter(rightVariables::contains) .collect(toImmutableList())); if (rightResult.equals(UNKNOWN_COST_RESULT)) { return UNKNOWN_COST_RESULT; @@ -289,8 +291,8 @@ private JoinEnumerationResult createJoin(LinkedHashSet leftSources, Li PlanNode right = rightResult.planNode.orElseThrow(() -> new VerifyException("Plan node is not present")); // sort output symbols so that the left input symbols are first - List sortedOutputSymbols = Stream.concat(left.getOutputSymbols().stream(), right.getOutputSymbols().stream()) - .filter(outputSymbols::contains) + List sortedOutputVariables = Stream.concat(left.getOutputVariables().stream(), right.getOutputVariables().stream()) + .filter(outputVariables::contains) .collect(toImmutableList()); return setJoinNodeProperties(new JoinNode( @@ -299,45 +301,52 @@ private JoinEnumerationResult createJoin(LinkedHashSet leftSources, Li left, right, joinConditions, - sortedOutputSymbols, + sortedOutputVariables, joinFilters.isEmpty() ? Optional.empty() : Optional.of(castToRowExpression(and(joinFilters))), Optional.empty(), Optional.empty(), Optional.empty())); } - private List getJoinPredicates(Set leftSymbols, Set rightSymbols) + private List getJoinPredicates(Set leftVariables, Set rightVariables) { ImmutableList.Builder joinPredicatesBuilder = ImmutableList.builder(); - + List leftSymbols = leftVariables.stream().map(VariableReferenceExpression::getName).map(Symbol::new).collect(toImmutableList()); + List rightSymbols = rightVariables.stream().map(VariableReferenceExpression::getName).map(Symbol::new).collect(toImmutableList()); // This takes all conjuncts that were part of allFilters that // could not be used for equality inference. // If they use both the left and right symbols, we add them to the list of joinPredicates stream(nonInferrableConjuncts(allFilter)) - .map(conjunct -> allFilterInference.rewriteExpression(conjunct, symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol))) + .map(conjunct -> allFilterInference.rewriteExpression( + conjunct, + variable -> leftVariables.contains(variable) || rightVariables.contains(variable), + context.getSymbolAllocator().getTypes())) .filter(Objects::nonNull) // filter expressions that contain only left or right symbols - .filter(conjunct -> allFilterInference.rewriteExpression(conjunct, leftSymbols::contains) == null) - .filter(conjunct -> allFilterInference.rewriteExpression(conjunct, rightSymbols::contains) == null) + .filter(conjunct -> allFilterInference.rewriteExpression(conjunct, leftVariables::contains, context.getSymbolAllocator().getTypes()) == null) + .filter(conjunct -> allFilterInference.rewriteExpression(conjunct, rightVariables::contains, context.getSymbolAllocator().getTypes()) == null) .forEach(joinPredicatesBuilder::add); // create equality inference on available symbols // TODO: make generateEqualitiesPartitionedBy take left and right scope - List joinEqualities = allFilterInference.generateEqualitiesPartitionedBy(symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol)).getScopeEqualities(); + List joinEqualities = allFilterInference.generateEqualitiesPartitionedBy( + variable -> leftVariables.contains(variable) || rightVariables.contains(variable), + context.getSymbolAllocator().getTypes()).getScopeEqualities(); EqualityInference joinInference = createEqualityInference(joinEqualities.toArray(new Expression[0])); - joinPredicatesBuilder.addAll(joinInference.generateEqualitiesPartitionedBy(in(leftSymbols)).getScopeStraddlingEqualities()); + joinPredicatesBuilder.addAll(joinInference.generateEqualitiesPartitionedBy(in(leftVariables), context.getSymbolAllocator().getTypes()).getScopeStraddlingEqualities()); return joinPredicatesBuilder.build(); } - private JoinEnumerationResult getJoinSource(LinkedHashSet nodes, List outputSymbols) + private JoinEnumerationResult getJoinSource(LinkedHashSet nodes, List outputVariables) { + List outputSymbols = outputVariables.stream().map(VariableReferenceExpression::getName).map(Symbol::new).collect(toImmutableList()); if (nodes.size() == 1) { PlanNode planNode = getOnlyElement(nodes); ImmutableList.Builder predicates = ImmutableList.builder(); - predicates.addAll(allFilterInference.generateEqualitiesPartitionedBy(outputSymbols::contains).getScopeEqualities()); + predicates.addAll(allFilterInference.generateEqualitiesPartitionedBy(outputVariables::contains, context.getSymbolAllocator().getTypes()).getScopeEqualities()); stream(nonInferrableConjuncts(allFilter)) - .map(conjunct -> allFilterInference.rewriteExpression(conjunct, outputSymbols::contains)) + .map(conjunct -> allFilterInference.rewriteExpression(conjunct, outputVariables::contains, context.getSymbolAllocator().getTypes())) .filter(Objects::nonNull) .forEach(predicates::add); Expression filter = combineConjuncts(predicates.build()); @@ -346,7 +355,7 @@ private JoinEnumerationResult getJoinSource(LinkedHashSet nodes, List< } return createJoinEnumerationResult(planNode); } - return chooseJoinOrder(nodes, outputSymbols); + return chooseJoinOrder(nodes, outputVariables); } private static boolean isJoinEqualityCondition(Expression expression) @@ -357,12 +366,12 @@ private static boolean isJoinEqualityCondition(Expression expression) && ((ComparisonExpression) expression).getRight() instanceof SymbolReference; } - private static EquiJoinClause toEquiJoinClause(ComparisonExpression equality, Set leftSymbols) + private static EquiJoinClause toEquiJoinClause(ComparisonExpression equality, Set leftVariables, SymbolAllocator symbolAllocator) { - Symbol leftSymbol = Symbol.from(equality.getLeft()); - Symbol rightSymbol = Symbol.from(equality.getRight()); - EquiJoinClause equiJoinClause = new EquiJoinClause(leftSymbol, rightSymbol); - return leftSymbols.contains(leftSymbol) ? equiJoinClause : equiJoinClause.flip(); + VariableReferenceExpression leftVariable = symbolAllocator.toVariableReference(Symbol.from(equality.getLeft())); + VariableReferenceExpression rightVariable = symbolAllocator.toVariableReference(Symbol.from(equality.getRight())); + EquiJoinClause equiJoinClause = new EquiJoinClause(leftVariable, rightVariable); + return leftVariables.contains(leftVariable) ? equiJoinClause : equiJoinClause.flip(); } private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode) @@ -429,21 +438,21 @@ static class MultiJoinNode // Use a linked hash set to ensure optimizer is deterministic private final LinkedHashSet sources; private final Expression filter; - private final List outputSymbols; + private final List outputVariables; - public MultiJoinNode(LinkedHashSet sources, Expression filter, List outputSymbols) + public MultiJoinNode(LinkedHashSet sources, Expression filter, List outputVariables) { requireNonNull(sources, "sources is null"); checkArgument(sources.size() > 1, "sources size is <= 1"); requireNonNull(filter, "filter is null"); - requireNonNull(outputSymbols, "outputSymbols is null"); + requireNonNull(outputVariables, "outputVariables is null"); this.sources = sources; this.filter = filter; - this.outputSymbols = ImmutableList.copyOf(outputSymbols); + this.outputVariables = ImmutableList.copyOf(outputVariables); - List inputSymbols = sources.stream().flatMap(source -> source.getOutputSymbols().stream()).collect(toImmutableList()); - checkArgument(inputSymbols.containsAll(outputSymbols), "inputs do not contain all output symbols"); + List inputVariables = sources.stream().flatMap(source -> source.getOutputVariables().stream()).collect(toImmutableList()); + checkArgument(inputVariables.containsAll(outputVariables), "inputs do not contain all output variables"); } public Expression getFilter() @@ -456,9 +465,9 @@ public LinkedHashSet getSources() return sources; } - public List getOutputSymbols() + public List getOutputVariables() { - return outputSymbols; + return outputVariables; } public static Builder builder() @@ -469,7 +478,7 @@ public static Builder builder() @Override public int hashCode() { - return Objects.hash(sources, ImmutableSet.copyOf(extractConjuncts(filter)), outputSymbols); + return Objects.hash(sources, ImmutableSet.copyOf(extractConjuncts(filter)), outputVariables); } @Override @@ -482,7 +491,7 @@ public boolean equals(Object obj) MultiJoinNode other = (MultiJoinNode) obj; return this.sources.equals(other.sources) && ImmutableSet.copyOf(extractConjuncts(this.filter)).equals(ImmutableSet.copyOf(extractConjuncts(other.filter))) - && this.outputSymbols.equals(other.outputSymbols); + && this.outputVariables.equals(other.outputVariables); } static MultiJoinNode toMultiJoinNode(JoinNode joinNode, Lookup lookup, int joinLimit) @@ -495,14 +504,14 @@ private static class JoinNodeFlattener { private final LinkedHashSet sources = new LinkedHashSet<>(); private final List filters = new ArrayList<>(); - private final List outputSymbols; + private final List outputVariables; private final Lookup lookup; JoinNodeFlattener(JoinNode node, Lookup lookup, int sourceLimit) { requireNonNull(node, "node is null"); checkState(node.getType() == INNER, "join type must be INNER"); - this.outputSymbols = node.getOutputSymbols(); + this.outputVariables = node.getOutputVariables(); this.lookup = requireNonNull(lookup, "lookup is null"); flattenNode(node, sourceLimit); } @@ -534,7 +543,7 @@ private void flattenNode(PlanNode node, int limit) MultiJoinNode toMultiJoinNode() { - return new MultiJoinNode(sources, and(filters), outputSymbols); + return new MultiJoinNode(sources, and(filters), outputVariables); } } @@ -542,7 +551,7 @@ static class Builder { private List sources; private Expression filter; - private List outputSymbols; + private List outputVariables; public Builder setSources(PlanNode... sources) { @@ -556,15 +565,15 @@ public Builder setFilter(Expression filter) return this; } - public Builder setOutputSymbols(Symbol... outputSymbols) + public Builder setOutputVariables(VariableReferenceExpression... outputVariables) { - this.outputSymbols = ImmutableList.copyOf(outputSymbols); + this.outputVariables = ImmutableList.copyOf(outputVariables); return this; } public MultiJoinNode build() { - return new MultiJoinNode(new LinkedHashSet<>(sources), filter, outputSymbols); + return new MultiJoinNode(new LinkedHashSet<>(sources), filter, outputVariables); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java index f9a6a8b81e508..f5766e05045cc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java @@ -16,9 +16,9 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeSignature; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; @@ -28,6 +28,7 @@ import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -87,26 +88,26 @@ public Pattern getPattern() @Override public Result apply(AggregationNode node, Captures captures, Context context) { - ImmutableMap.Builder aggregations = ImmutableMap.builder(); - Symbol partitionCountSymbol = context.getSymbolAllocator().newSymbol("partition_count", INTEGER); - ImmutableMap.Builder envelopeAssignments = ImmutableMap.builder(); - for (Map.Entry entry : node.getAggregations().entrySet()) { + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + VariableReferenceExpression partitionCountVariable = context.getSymbolAllocator().newVariable("partition_count", INTEGER); + ImmutableMap.Builder envelopeAssignments = ImmutableMap.builder(); + for (Map.Entry entry : node.getAggregations().entrySet()) { Aggregation aggregation = entry.getValue(); String name = metadata.getFunctionManager().getFunctionMetadata(aggregation.getFunctionHandle()).getName(); Type geometryType = metadata.getType(GEOMETRY_TYPE_SIGNATURE); if (name.equals(NAME) && aggregation.getArguments().size() == 1) { Expression geometry = getOnlyElement(aggregation.getArguments()); - Symbol envelopeSymbol = context.getSymbolAllocator().newSymbol("envelope", geometryType); + VariableReferenceExpression envelopeVariable = context.getSymbolAllocator().newVariable("envelope", geometryType); if (geometry instanceof FunctionCall && ((FunctionCall) geometry).getName().toString().equalsIgnoreCase("ST_Envelope")) { - envelopeAssignments.put(envelopeSymbol, geometry); + envelopeAssignments.put(envelopeVariable, geometry); } else { - envelopeAssignments.put(envelopeSymbol, new FunctionCall(QualifiedName.of("ST_Envelope"), ImmutableList.of(geometry))); + envelopeAssignments.put(envelopeVariable, new FunctionCall(QualifiedName.of("ST_Envelope"), ImmutableList.of(geometry))); } aggregations.put(entry.getKey(), new Aggregation( metadata.getFunctionManager().lookupFunction(NAME, fromTypes(geometryType, INTEGER)), - ImmutableList.of(envelopeSymbol.toSymbolReference(), partitionCountSymbol.toSymbolReference()), + ImmutableList.of(new SymbolReference(envelopeVariable.getName()), new SymbolReference(partitionCountVariable.getName())), Optional.empty(), Optional.empty(), false, @@ -124,15 +125,15 @@ public Result apply(AggregationNode node, Captures captures, Context context) context.getIdAllocator().getNextId(), node.getSource(), Assignments.builder() - .putIdentities(node.getSource().getOutputSymbols()) - .put(partitionCountSymbol, new LongLiteral(Integer.toString(getHashPartitionCount(context.getSession())))) + .putIdentities(node.getSource().getOutputVariables()) + .put(partitionCountVariable, new LongLiteral(Integer.toString(getHashPartitionCount(context.getSession())))) .putAll(envelopeAssignments.build()) .build()), aggregations.build(), node.getGroupingSets(), - node.getPreGroupedSymbols(), + node.getPreGroupedVariables(), node.getStep(), - node.getHashSymbol(), - node.getGroupIdSymbol())); + node.getHashVariable(), + node.getGroupIdVariable())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java index 39bb26089c132..34094f705b4cd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java @@ -18,6 +18,7 @@ import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -69,15 +70,15 @@ public Result apply(AggregationNode parent, Captures captures, Context context) ProjectNode child = captures.get(CHILD); boolean changed = false; - Map aggregations = new LinkedHashMap<>(parent.getAggregations()); + Map aggregations = new LinkedHashMap<>(parent.getAggregations()); - for (Entry entry : parent.getAggregations().entrySet()) { - Symbol symbol = entry.getKey(); + for (Entry entry : parent.getAggregations().entrySet()) { + VariableReferenceExpression variable = entry.getKey(); AggregationNode.Aggregation aggregation = entry.getValue(); if (isCountOverConstant(aggregation, child.getAssignments())) { changed = true; - aggregations.put(symbol, new AggregationNode.Aggregation( + aggregations.put(variable, new AggregationNode.Aggregation( functionResolution.countFunction(), ImmutableList.of(), Optional.empty(), @@ -98,8 +99,8 @@ public Result apply(AggregationNode parent, Captures captures, Context context) parent.getGroupingSets(), ImmutableList.of(), parent.getStep(), - parent.getHashSymbol(), - parent.getGroupIdSymbol())); + parent.getHashVariable(), + parent.getGroupIdVariable())); } private boolean isCountOverConstant(AggregationNode.Aggregation aggregation, Assignments inputs) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java index 2078bb0b7e593..ab47ad86898ec 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java @@ -15,6 +15,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -118,8 +119,9 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont List> argumentSets = extractArgumentSets(aggregation) .collect(Collectors.toList()); - Set symbols = Iterables.getOnlyElement(argumentSets).stream() + Set variables = Iterables.getOnlyElement(argumentSets).stream() .map(Symbol::from) + .map(context.getSymbolAllocator()::toVariableReference) .collect(Collectors.toSet()); return Result.ofPlanNode( @@ -129,9 +131,9 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont context.getIdAllocator().getNextId(), aggregation.getSource(), ImmutableMap.of(), - singleGroupingSet(ImmutableList.builder() + singleGroupingSet(ImmutableList.builder() .addAll(aggregation.getGroupingKeys()) - .addAll(symbols) + .addAll(variables) .build()), ImmutableList.of(), SINGLE, @@ -146,8 +148,8 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont aggregation.getGroupingSets(), emptyList(), aggregation.getStep(), - aggregation.getHashSymbol(), - aggregation.getGroupIdSymbol())); + aggregation.getHashVariable(), + aggregation.getGroupIdVariable())); } private static AggregationNode.Aggregation removeDistinct(AggregationNode.Aggregation aggregation) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java index ffb8604054872..240f0b5d49717 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java @@ -18,9 +18,11 @@ import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -125,20 +127,20 @@ public Result apply(ApplyNode apply, Captures captures, Context context) } InPredicate inPredicate = (InPredicate) assignmentExpression; - Symbol inPredicateOutputSymbol = getOnlyElement(subqueryAssignments.getSymbols()); + VariableReferenceExpression inPredicateOutputVariable = getOnlyElement(subqueryAssignments.getVariables()); - return apply(apply, inPredicate, inPredicateOutputSymbol, context.getLookup(), context.getIdAllocator(), context.getSymbolAllocator()); + return apply(apply, inPredicate, inPredicateOutputVariable, context.getLookup(), context.getIdAllocator(), context.getSymbolAllocator()); } private Result apply( ApplyNode apply, InPredicate inPredicate, - Symbol inPredicateOutputSymbol, + VariableReferenceExpression inPredicateOutputVariable, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) { - Optional decorrelated = new DecorrelatingVisitor(lookup, apply.getCorrelation()) + Optional decorrelated = new DecorrelatingVisitor(lookup, apply.getCorrelation(), symbolAllocator.getTypes()) .decorrelate(apply.getSubquery()); if (!decorrelated.isPresent()) { @@ -148,7 +150,7 @@ private Result apply( PlanNode projection = buildInPredicateEquivalent( apply, inPredicate, - inPredicateOutputSymbol, + inPredicateOutputVariable, decorrelated.get(), idAllocator, symbolAllocator); @@ -159,7 +161,7 @@ private Result apply( private PlanNode buildInPredicateEquivalent( ApplyNode apply, InPredicate inPredicate, - Symbol inPredicateOutputSymbol, + VariableReferenceExpression inPredicateOutputVariable, Decorrelated decorrelated, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) @@ -170,48 +172,48 @@ private PlanNode buildInPredicateEquivalent( AssignUniqueId probeSide = new AssignUniqueId( idAllocator.getNextId(), apply.getInput(), - symbolAllocator.newSymbol("unique", BIGINT)); + symbolAllocator.newVariable("unique", BIGINT)); - Symbol buildSideKnownNonNull = symbolAllocator.newSymbol("buildSideKnownNonNull", BIGINT); + VariableReferenceExpression buildSideKnownNonNull = symbolAllocator.newVariable("buildSideKnownNonNull", BIGINT); ProjectNode buildSide = new ProjectNode( idAllocator.getNextId(), decorrelatedBuildSource, Assignments.builder() - .putIdentities(decorrelatedBuildSource.getOutputSymbols()) + .putIdentities(decorrelatedBuildSource.getOutputVariables()) .put(buildSideKnownNonNull, bigint(0)) .build()); - Symbol probeSideSymbol = Symbol.from(inPredicate.getValue()); - Symbol buildSideSymbol = Symbol.from(inPredicate.getValueList()); + SymbolReference probeSideSymbolReference = Symbol.from(inPredicate.getValue()).toSymbolReference(); + SymbolReference buildSideSymbolReference = Symbol.from(inPredicate.getValueList()).toSymbolReference(); Expression joinExpression = and( or( - new IsNullPredicate(probeSideSymbol.toSymbolReference()), - new ComparisonExpression(ComparisonExpression.Operator.EQUAL, probeSideSymbol.toSymbolReference(), buildSideSymbol.toSymbolReference()), - new IsNullPredicate(buildSideSymbol.toSymbolReference())), + new IsNullPredicate(probeSideSymbolReference), + new ComparisonExpression(ComparisonExpression.Operator.EQUAL, probeSideSymbolReference, buildSideSymbolReference), + new IsNullPredicate(buildSideSymbolReference)), correlationCondition); JoinNode leftOuterJoin = leftOuterJoin(idAllocator, probeSide, buildSide, joinExpression); - Symbol countMatchesSymbol = symbolAllocator.newSymbol("countMatches", BIGINT); - Symbol countNullMatchesSymbol = symbolAllocator.newSymbol("countNullMatches", BIGINT); + VariableReferenceExpression countMatchesVariable = symbolAllocator.newVariable("countMatches", BIGINT); + VariableReferenceExpression countNullMatchesVariable = symbolAllocator.newVariable("countNullMatches", BIGINT); Expression matchCondition = and( - isNotNull(probeSideSymbol), - isNotNull(buildSideSymbol)); + new IsNotNullPredicate(probeSideSymbolReference), + new IsNotNullPredicate(buildSideSymbolReference)); Expression nullMatchCondition = and( - isNotNull(buildSideKnownNonNull), - not(matchCondition)); + new IsNotNullPredicate(new SymbolReference(buildSideKnownNonNull.getName())), + new NotExpression(matchCondition)); AggregationNode aggregation = new AggregationNode( idAllocator.getNextId(), leftOuterJoin, - ImmutableMap.builder() - .put(countMatchesSymbol, countWithFilter(matchCondition)) - .put(countNullMatchesSymbol, countWithFilter(nullMatchCondition)) + ImmutableMap.builder() + .put(countMatchesVariable, countWithFilter(matchCondition)) + .put(countNullMatchesVariable, countWithFilter(nullMatchCondition)) .build(), - singleGroupingSet(probeSide.getOutputSymbols()), + singleGroupingSet(probeSide.getOutputVariables()), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), @@ -220,15 +222,15 @@ private PlanNode buildInPredicateEquivalent( // TODO since we care only about "some count > 0", we could have specialized node instead of leftOuterJoin that does the job without materializing join results SearchedCaseExpression inPredicateEquivalent = new SearchedCaseExpression( ImmutableList.of( - new WhenClause(isGreaterThan(countMatchesSymbol, 0), booleanConstant(true)), - new WhenClause(isGreaterThan(countNullMatchesSymbol, 0), booleanConstant(null))), + new WhenClause(isGreaterThan(countMatchesVariable, 0), booleanConstant(true)), + new WhenClause(isGreaterThan(countNullMatchesVariable, 0), booleanConstant(null))), Optional.of(booleanConstant(false))); return new ProjectNode( idAllocator.getNextId(), aggregation, Assignments.builder() - .putIdentities(apply.getInput().getOutputSymbols()) - .put(inPredicateOutputSymbol, inPredicateEquivalent) + .putIdentities(apply.getInput().getOutputVariables()) + .put(inPredicateOutputVariable, inPredicateEquivalent) .build()); } @@ -240,9 +242,9 @@ private static JoinNode leftOuterJoin(PlanNodeIdAllocator idAllocator, AssignUni probeSide, buildSide, ImmutableList.of(), - ImmutableList.builder() - .addAll(probeSide.getOutputSymbols()) - .addAll(buildSide.getOutputSymbols()) + ImmutableList.builder() + .addAll(probeSide.getOutputVariables()) + .addAll(buildSide.getOutputVariables()) .build(), Optional.of(castToRowExpression(joinExpression)), Optional.empty(), @@ -258,27 +260,17 @@ private AggregationNode.Aggregation countWithFilter(Expression condition) Optional.of(condition), Optional.empty(), false, - Optional.empty()); /* mask */ + Optional.empty()); /* mask */ } - private static Expression isGreaterThan(Symbol symbol, long value) + private static Expression isGreaterThan(VariableReferenceExpression variable, long value) { return new ComparisonExpression( ComparisonExpression.Operator.GREATER_THAN, - symbol.toSymbolReference(), + new SymbolReference(variable.getName()), bigint(value)); } - private static Expression not(Expression booleanExpression) - { - return new NotExpression(booleanExpression); - } - - private static Expression isNotNull(Symbol symbol) - { - return new IsNotNullPredicate(symbol.toSymbolReference()); - } - private static Expression bigint(long value) { return new Cast(new LongLiteral(String.valueOf(value)), BIGINT.toString()); @@ -296,12 +288,14 @@ private static class DecorrelatingVisitor extends InternalPlanVisitor, PlanNode> { private final Lookup lookup; - private final Set correlation; + private final Set correlation; + private final TypeProvider types; - public DecorrelatingVisitor(Lookup lookup, Iterable correlation) + public DecorrelatingVisitor(Lookup lookup, Iterable correlation, TypeProvider types) { this.lookup = requireNonNull(lookup, "lookup is null"); this.correlation = ImmutableSet.copyOf(requireNonNull(correlation, "correlation is null")); + this.types = requireNonNull(types, "types is null"); } public Optional decorrelate(PlanNode reference) @@ -327,8 +321,9 @@ public Optional visitProject(ProjectNode node, PlanNode reference) .flatMap(AstUtils::preOrder) .filter(SymbolReference.class::isInstance) .map(SymbolReference.class::cast) - .filter(symbolReference -> !correlation.contains(Symbol.from(symbolReference))) - .forEach(symbolReference -> assignments.putIdentity(Symbol.from(symbolReference))); + .map(symbolReference -> new VariableReferenceExpression(symbolReference.getName(), types.get(Symbol.from(symbolReference)))) + .filter(variable -> !correlation.contains(variable)) + .forEach(assignments::putIdentity); return new Decorrelated( decorrelated.getCorrelatedPredicates(), @@ -376,7 +371,7 @@ private boolean isCorrelatedRecursively(PlanNode node) private boolean isCorrelatedShallowly(PlanNode node) { - return SymbolsExtractor.extractUniqueNonRecursive(node).stream().anyMatch(correlation::contains); + return SymbolsExtractor.extractUniqueNonRecursive(node).stream().map(symbol -> new VariableReferenceExpression(symbol.getName(), types.get(symbol))).anyMatch(correlation::contains); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedLateralJoinToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedLateralJoinToJoin.java index 0d8277db9ae94..cc1477d992609 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedLateralJoinToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedLateralJoinToJoin.java @@ -51,7 +51,7 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context { PlanNode subquery = lateralJoinNode.getSubquery(); - PlanNodeDecorrelator planNodeDecorrelator = new PlanNodeDecorrelator(context.getIdAllocator(), context.getLookup()); + PlanNodeDecorrelator planNodeDecorrelator = new PlanNodeDecorrelator(context.getIdAllocator(), context.getSymbolAllocator(), context.getLookup()); Optional decorrelatedNodeOptional = planNodeDecorrelator.decorrelateFilters(subquery, lateralJoinNode.getCorrelation()); return decorrelatedNodeOptional.map(decorrelatedNode -> @@ -61,7 +61,7 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context lateralJoinNode.getInput(), decorrelatedNode.getNode(), ImmutableList.of(), - lateralJoinNode.getOutputSymbols(), + lateralJoinNode.getOutputVariables(), decorrelatedNode.getCorrelatedPredicates().map(OriginalExpressionUtils::castToRowExpression), Optional.empty(), Optional.empty(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java index f25319afec143..a1f212660295b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java @@ -15,9 +15,8 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.BooleanType; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.Assignments; @@ -33,6 +32,7 @@ import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.SimpleCaseExpression; import com.facebook.presto.sql.tree.StringLiteral; +import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.WhenClause; import com.google.common.collect.ImmutableList; @@ -40,6 +40,7 @@ import static com.facebook.presto.matching.Pattern.nonEmpty; import static com.facebook.presto.spi.StandardErrorCode.SUBQUERY_MULTIPLE_ROWS; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.StandardTypes.BOOLEAN; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar; @@ -115,7 +116,7 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context lateralJoinNode.getOriginSubqueryError())); } - Symbol unique = context.getSymbolAllocator().newSymbol("unique", BigintType.BIGINT); + VariableReferenceExpression unique = context.getSymbolAllocator().newVariable("unique", BIGINT); LateralJoinNode rewrittenLateralJoinNode = new LateralJoinNode( context.getIdAllocator().getNextId(), @@ -128,19 +129,19 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context lateralJoinNode.getType(), lateralJoinNode.getOriginSubqueryError()); - Symbol isDistinct = context.getSymbolAllocator().newSymbol("is_distinct", BooleanType.BOOLEAN); + VariableReferenceExpression isDistinct = context.getSymbolAllocator().newVariable("is_distinct", BooleanType.BOOLEAN); MarkDistinctNode markDistinctNode = new MarkDistinctNode( context.getIdAllocator().getNextId(), rewrittenLateralJoinNode, isDistinct, - rewrittenLateralJoinNode.getInput().getOutputSymbols(), + rewrittenLateralJoinNode.getInput().getOutputVariables(), Optional.empty()); FilterNode filterNode = new FilterNode( context.getIdAllocator().getNextId(), markDistinctNode, castToRowExpression(new SimpleCaseExpression( - isDistinct.toSymbolReference(), + new SymbolReference(isDistinct.getName()), ImmutableList.of( new WhenClause(TRUE_LITERAL, TRUE_LITERAL)), Optional.of(new Cast( @@ -154,6 +155,6 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context return Result.ofPlanNode(new ProjectNode( context.getIdAllocator().getNextId(), filterNode, - Assignments.identity(lateralJoinNode.getOutputSymbols()))); + Assignments.identity(lateralJoinNode.getOutputVariables()))); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java index ec83a14e80039..9a33e46b9fbb6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java @@ -68,7 +68,7 @@ public Result apply(LateralJoinNode parent, Captures captures, Context context) } List subqueryProjections = searchFrom(parent.getSubquery(), context.getLookup()) - .where(node -> node instanceof ProjectNode && !node.getOutputSymbols().equals(parent.getCorrelation())) + .where(node -> node instanceof ProjectNode && !node.getOutputVariables().equals(parent.getCorrelation())) .findAll(); if (subqueryProjections.size() == 0) { @@ -76,7 +76,7 @@ public Result apply(LateralJoinNode parent, Captures captures, Context context) } else if (subqueryProjections.size() == 1) { Assignments assignments = Assignments.builder() - .putIdentities(parent.getInput().getOutputSymbols()) + .putIdentities(parent.getInput().getOutputVariables()) .putAll(subqueryProjections.get(0).getAssignments()) .build(); return Result.ofPlanNode(projectNode(parent.getInput(), assignments, context)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java index 1587ea723cdb9..5d57c6cf13737 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java @@ -17,7 +17,7 @@ import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.spi.function.StandardFunctionResolution; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.PlanNodeDecorrelator; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -36,6 +36,7 @@ import com.facebook.presto.sql.tree.ExistsPredicate; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -112,14 +113,14 @@ public Result apply(ApplyNode parent, Captures captures, Context context) private Optional rewriteToNonDefaultAggregation(ApplyNode applyNode, Context context) { - checkState(applyNode.getSubquery().getOutputSymbols().isEmpty(), "Expected subquery output symbols to be pruned"); + checkState(applyNode.getSubquery().getOutputVariables().isEmpty(), "Expected subquery output variables to be pruned"); - Symbol exists = getOnlyElement(applyNode.getSubqueryAssignments().getSymbols()); - Symbol subqueryTrue = context.getSymbolAllocator().newSymbol("subqueryTrue", BOOLEAN); + VariableReferenceExpression exists = getOnlyElement(applyNode.getSubqueryAssignments().getVariables()); + VariableReferenceExpression subqueryTrue = context.getSymbolAllocator().newVariable("subqueryTrue", BOOLEAN); Assignments.Builder assignments = Assignments.builder(); - assignments.putIdentities(applyNode.getInput().getOutputSymbols()); - assignments.put(exists, new CoalesceExpression(ImmutableList.of(subqueryTrue.toSymbolReference(), BooleanLiteral.FALSE_LITERAL))); + assignments.putIdentities(applyNode.getInput().getOutputVariables()); + assignments.put(exists, new CoalesceExpression(ImmutableList.of(new SymbolReference(subqueryTrue.getName()), BooleanLiteral.FALSE_LITERAL))); PlanNode subquery = new ProjectNode( context.getIdAllocator().getNextId(), @@ -130,7 +131,7 @@ private Optional rewriteToNonDefaultAggregation(ApplyNode applyNode, C false), Assignments.of(subqueryTrue, TRUE_LITERAL)); - PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(context.getIdAllocator(), context.getLookup()); + PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(context.getIdAllocator(), context.getSymbolAllocator(), context.getLookup()); if (!decorrelator.decorrelateFilters(subquery, applyNode.getCorrelation()).isPresent()) { return Optional.empty(); } @@ -148,8 +149,8 @@ private Optional rewriteToNonDefaultAggregation(ApplyNode applyNode, C private PlanNode rewriteToDefaultAggregation(ApplyNode parent, Context context) { - Symbol count = context.getSymbolAllocator().newSymbol("count", BIGINT); - Symbol exists = getOnlyElement(parent.getSubqueryAssignments().getSymbols()); + VariableReferenceExpression count = context.getSymbolAllocator().newVariable("count", BIGINT); + VariableReferenceExpression exists = getOnlyElement(parent.getSubqueryAssignments().getVariables()); return new LateralJoinNode( parent.getId(), @@ -171,7 +172,7 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode parent, Context context) AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty()), - Assignments.of(exists, new ComparisonExpression(GREATER_THAN, count.toSymbolReference(), new Cast(new LongLiteral("0"), BIGINT.toString())))), + Assignments.of(exists, new ComparisonExpression(GREATER_THAN, new SymbolReference(count.getName()), new Cast(new LongLiteral("0"), BIGINT.toString())))), parent.getCorrelation(), INNER, parent.getOriginSubqueryError()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java index d844134e172da..362dbc7a2e14f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java @@ -15,6 +15,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.ApplyNode; @@ -77,14 +78,14 @@ public Result apply(ApplyNode applyNode, Captures captures, Context context) } InPredicate inPredicate = (InPredicate) expression; - Symbol semiJoinSymbol = getOnlyElement(applyNode.getSubqueryAssignments().getSymbols()); + VariableReferenceExpression semiJoinVariable = getOnlyElement(applyNode.getSubqueryAssignments().getVariables()); SemiJoinNode replacement = new SemiJoinNode(context.getIdAllocator().getNextId(), applyNode.getInput(), applyNode.getSubquery(), - Symbol.from(inPredicate.getValue()), - Symbol.from(inPredicate.getValueList()), - semiJoinSymbol, + context.getSymbolAllocator().toVariableReference(Symbol.from(inPredicate.getValue())), + context.getSymbolAllocator().toVariableReference(Symbol.from(inPredicate.getValueList())), + semiJoinVariable, Optional.empty(), Optional.empty(), Optional.empty()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java index da2a951396d69..766d353da96b5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java @@ -15,7 +15,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; @@ -48,9 +48,9 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context lateralJoinNode.getInput(), lateralJoinNode.getSubquery(), ImmutableList.of(), - ImmutableList.builder() - .addAll(lateralJoinNode.getInput().getOutputSymbols()) - .addAll(lateralJoinNode.getSubquery().getOutputSymbols()) + ImmutableList.builder() + .addAll(lateralJoinNode.getInput().getOutputVariables()) + .addAll(lateralJoinNode.getSubquery().getOutputVariables()) .build(), Optional.empty(), Optional.empty(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java index fb0e17b024ed2..3a74e8d4a0cee 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java @@ -14,8 +14,9 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -43,10 +44,10 @@ private Util() *

    * If all inputs are used, return Optional.empty() to indicate that no pruning is necessary. */ - public static Optional> pruneInputs(Collection availableInputs, Collection expressions) + public static Optional> pruneInputs(Collection availableInputs, Collection expressions, TypeProvider types) { - Set availableInputsSet = ImmutableSet.copyOf(availableInputs); - Set prunedInputs = Sets.filter(availableInputsSet, SymbolsExtractor.extractUnique(expressions)::contains); + Set availableInputsSet = ImmutableSet.copyOf(availableInputs); + Set prunedInputs = Sets.filter(availableInputsSet, SymbolsExtractor.extractUniqueVariable(expressions, types)::contains); if (prunedInputs.size() == availableInputsSet.size()) { return Optional.empty(); @@ -68,13 +69,13 @@ public static PlanNode transpose(PlanNode parent, PlanNode child) /** * @return If the node has outputs not in permittedOutputs, returns an identity projection containing only those node outputs also in permittedOutputs. */ - public static Optional restrictOutputs(PlanNodeIdAllocator idAllocator, PlanNode node, Set permittedOutputs) + public static Optional restrictOutputs(PlanNodeIdAllocator idAllocator, PlanNode node, Set permittedOutputs) { - List restrictedOutputs = node.getOutputSymbols().stream() + List restrictedOutputs = node.getOutputVariables().stream() .filter(permittedOutputs::contains) .collect(toImmutableList()); - if (restrictedOutputs.size() == node.getOutputSymbols().size()) { + if (restrictedOutputs.size() == node.getOutputVariables().size()) { return Optional.empty(); } @@ -90,9 +91,9 @@ public static Optional restrictOutputs(PlanNodeIdAllocator idAllocator * Returns a present Optional iff at least one child was rewritten. */ @SafeVarargs - public static Optional restrictChildOutputs(PlanNodeIdAllocator idAllocator, PlanNode node, Set... permittedChildOutputsArgs) + public static Optional restrictChildOutputs(PlanNodeIdAllocator idAllocator, PlanNode node, Set... permittedChildOutputsArgs) { - List> permittedChildOutputs = ImmutableList.copyOf(permittedChildOutputsArgs); + List> permittedChildOutputs = ImmutableList.copyOf(permittedChildOutputsArgs); checkArgument( (node.getSources().size() == permittedChildOutputs.size()), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ActualProperties.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ActualProperties.java index 84cefce6b65ed..f419eeb8ede87 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ActualProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ActualProperties.java @@ -17,10 +17,10 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.ConstantProperty; import com.facebook.presto.spi.LocalProperty; -import com.facebook.presto.spi.predicate.NullableValue; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningHandle; -import com.facebook.presto.sql.planner.Symbol; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -48,13 +48,13 @@ public class ActualProperties { private final Global global; - private final List> localProperties; - private final Map constants; + private final List> localProperties; + private final Map constants; private ActualProperties( Global global, - List> localProperties, - Map constants) + List> localProperties, + Map constants) { requireNonNull(global, "globalProperties is null"); requireNonNull(localProperties, "localProperties is null"); @@ -64,15 +64,15 @@ private ActualProperties( // The constants field implies a ConstantProperty in localProperties (but not vice versa). // Let's make sure to include the constants into the local constant properties. - Set localConstants = LocalProperties.extractLeadingConstants(localProperties); + Set localConstants = LocalProperties.extractLeadingConstants(localProperties); localProperties = LocalProperties.stripLeadingConstants(localProperties); - Set updatedLocalConstants = ImmutableSet.builder() + Set updatedLocalConstants = ImmutableSet.builder() .addAll(localConstants) .addAll(constants.keySet()) .build(); - List> updatedLocalProperties = LocalProperties.normalizeAndPrune(ImmutableList.>builder() + List> updatedLocalProperties = LocalProperties.normalizeAndPrune(ImmutableList.>builder() .addAll(transform(updatedLocalConstants, ConstantProperty::new)) .addAll(localProperties) .build()); @@ -99,22 +99,22 @@ public boolean isNullsAndAnyReplicated() return global.isNullsAndAnyReplicated(); } - public boolean isStreamPartitionedOn(Collection columns) + public boolean isStreamPartitionedOn(Collection columns) { return isStreamPartitionedOn(columns, false); } - public boolean isStreamPartitionedOn(Collection columns, boolean nullsAndAnyReplicated) + public boolean isStreamPartitionedOn(Collection columns, boolean nullsAndAnyReplicated) { return global.isStreamPartitionedOn(columns, constants.keySet(), nullsAndAnyReplicated); } - public boolean isNodePartitionedOn(Collection columns) + public boolean isNodePartitionedOn(Collection columns) { return isNodePartitionedOn(columns, false); } - public boolean isNodePartitionedOn(Collection columns, boolean nullsAndAnyReplicated) + public boolean isNodePartitionedOn(Collection columns, boolean nullsAndAnyReplicated) { return global.isNodePartitionedOn(columns, constants.keySet(), nullsAndAnyReplicated); } @@ -126,13 +126,13 @@ public boolean isCompatibleTablePartitioningWith(Partitioning partitioning, bool } @Deprecated - public boolean isCompatibleTablePartitioningWith(ActualProperties other, Function> symbolMappings, Metadata metadata, Session session) + public boolean isCompatibleTablePartitioningWith(ActualProperties other, Function> symbolMappings, Metadata metadata, Session session) { return global.isCompatibleTablePartitioningWith( other.global, symbolMappings, - symbol -> Optional.ofNullable(constants.get(symbol)), - symbol -> Optional.ofNullable(other.constants.get(symbol)), + variable -> Optional.ofNullable(constants.get(variable)), + variable -> Optional.ofNullable(other.constants.get(variable)), metadata, session); } @@ -142,13 +142,13 @@ public boolean isRefinedPartitioningOver(Partitioning partitioning, boolean null return global.isRefinedPartitioningOver(partitioning, nullsAndAnyReplicated, metadata, session); } - public boolean isRefinedPartitioningOver(ActualProperties other, Function> symbolMappings, Metadata metadata, Session session) + public boolean isRefinedPartitioningOver(ActualProperties other, Function> symbolMappings, Metadata metadata, Session session) { return global.isRefinedPartitioningOver( other.global, symbolMappings, - symbol -> Optional.ofNullable(constants.get(symbol)), - symbol -> Optional.ofNullable(other.constants.get(symbol)), + variable -> Optional.ofNullable(constants.get(variable)), + variable -> Optional.ofNullable(other.constants.get(variable)), metadata, session); } @@ -164,16 +164,16 @@ public boolean isEffectivelySingleStream() /** * @return true if repartitioning on the keys will yield some difference */ - public boolean isStreamRepartitionEffective(Collection keys) + public boolean isStreamRepartitionEffective(Collection keys) { return global.isStreamRepartitionEffective(keys, constants.keySet()); } - public ActualProperties translate(Function> translator) + public ActualProperties translate(Function> translator) { - Map translatedConstants = new HashMap<>(); - for (Map.Entry entry : constants.entrySet()) { - Optional translatedKey = translator.apply(entry.getKey()); + Map translatedConstants = new HashMap<>(); + for (Map.Entry entry : constants.entrySet()) { + Optional translatedKey = translator.apply(entry.getKey()); if (translatedKey.isPresent()) { translatedConstants.put(translatedKey.get(), entry.getValue()); } @@ -190,12 +190,12 @@ public Optional getNodePartitioning() return global.getNodePartitioning(); } - public Map getConstants() + public Map getConstants() { return constants; } - public List> getLocalProperties() + public List> getLocalProperties() { return localProperties; } @@ -220,8 +220,8 @@ public static Builder builderFrom(ActualProperties properties) public static class Builder { private Global global; - private List> localProperties; - private Map constants; + private List> localProperties; + private Map constants; private boolean unordered; public Builder() @@ -229,7 +229,7 @@ public Builder() this(Global.arbitraryPartition(), ImmutableList.of(), ImmutableMap.of()); } - public Builder(Global global, List> localProperties, Map constants) + public Builder(Global global, List> localProperties, Map constants) { this.global = requireNonNull(global, "global is null"); this.localProperties = ImmutableList.copyOf(localProperties); @@ -248,13 +248,13 @@ public Builder global(ActualProperties other) return this; } - public Builder local(List> localProperties) + public Builder local(List> localProperties) { this.localProperties = ImmutableList.copyOf(localProperties); return this; } - public Builder constants(Map constants) + public Builder constants(Map constants) { this.constants = ImmutableMap.copyOf(constants); return this; @@ -268,7 +268,7 @@ public Builder unordered(boolean unordered) public ActualProperties build() { - List> localProperties = this.localProperties; + List> localProperties = this.localProperties; if (unordered) { localProperties = filteredCopy(this.localProperties, property -> !property.isOrderSensitive()); } @@ -326,8 +326,8 @@ private Global(Optional nodePartitioning, Optional s { checkArgument(!nodePartitioning.isPresent() || !streamPartitioning.isPresent() - || nodePartitioning.get().getColumns().containsAll(streamPartitioning.get().getColumns()) - || streamPartitioning.get().getColumns().containsAll(nodePartitioning.get().getColumns()), + || nodePartitioning.get().getVariableReferences().containsAll(streamPartitioning.get().getVariableReferences()) + || streamPartitioning.get().getVariableReferences().containsAll(nodePartitioning.get().getVariableReferences()), "Global stream partitioning columns should match node partitioning columns"); this.nodePartitioning = requireNonNull(nodePartitioning, "nodePartitioning is null"); this.streamPartitioning = requireNonNull(streamPartitioning, "streamPartitioning is null"); @@ -355,7 +355,10 @@ public static Global arbitraryPartition() return new Global(Optional.empty(), Optional.empty(), false); } - public static Global partitionedOn(PartitioningHandle nodePartitioningHandle, List nodePartitioning, Optional> streamPartitioning) + public static Global partitionedOn( + PartitioningHandle nodePartitioningHandle, + List nodePartitioning, + Optional> streamPartitioning) { return new Global( Optional.of(Partitioning.create(nodePartitioningHandle, nodePartitioning)), @@ -371,7 +374,7 @@ public static Global partitionedOn(Partitioning nodePartitioning, Optional streamPartitioning) + public static Global streamPartitionedOn(List streamPartitioning) { return new Global( Optional.empty(), @@ -410,7 +413,7 @@ private boolean isCoordinatorOnly() return nodePartitioning.get().getHandle().isCoordinatorOnly(); } - private boolean isNodePartitionedOn(Collection columns, Set constants, boolean nullsAndAnyReplicated) + private boolean isNodePartitionedOn(Collection columns, Set constants, boolean nullsAndAnyReplicated) { return nodePartitioning.isPresent() && nodePartitioning.get().isPartitionedOn(columns, constants) && this.nullsAndAnyReplicated == nullsAndAnyReplicated; } @@ -422,9 +425,9 @@ private boolean isCompatibleTablePartitioningWith(Partitioning partitioning, boo private boolean isCompatibleTablePartitioningWith( Global other, - Function> symbolMappings, - Function> leftConstantMapping, - Function> rightConstantMapping, + Function> symbolMappings, + Function> leftConstantMapping, + Function> rightConstantMapping, Metadata metadata, Session session) { @@ -447,9 +450,9 @@ private boolean isRefinedPartitioningOver(Partitioning partitioning, boolean nul private boolean isRefinedPartitioningOver( Global other, - Function> symbolMappings, - Function> leftConstantMapping, - Function> rightConstantMapping, + Function> symbolMappings, + Function> leftConstantMapping, + Function> rightConstantMapping, Metadata metadata, Session session) { @@ -470,7 +473,7 @@ private Optional getNodePartitioning() return nodePartitioning; } - private boolean isStreamPartitionedOn(Collection columns, Set constants, boolean nullsAndAnyReplicated) + private boolean isStreamPartitionedOn(Collection columns, Set constants, boolean nullsAndAnyReplicated) { return streamPartitioning.isPresent() && streamPartitioning.get().isPartitionedOn(columns, constants) && this.nullsAndAnyReplicated == nullsAndAnyReplicated; } @@ -478,7 +481,7 @@ private boolean isStreamPartitionedOn(Collection columns, Set co /** * @return true if all the data will effectively land in a single stream */ - private boolean isEffectivelySingleStream(Set constants) + private boolean isEffectivelySingleStream(Set constants) { return streamPartitioning.isPresent() && streamPartitioning.get().isEffectivelySinglePartition(constants) && !nullsAndAnyReplicated; } @@ -486,12 +489,14 @@ private boolean isEffectivelySingleStream(Set constants) /** * @return true if repartitioning on the keys will yield some difference */ - private boolean isStreamRepartitionEffective(Collection keys, Set constants) + private boolean isStreamRepartitionEffective(Collection keys, Set constants) { return (!streamPartitioning.isPresent() || streamPartitioning.get().isRepartitionEffective(keys, constants)) && !nullsAndAnyReplicated; } - private Global translate(Function> translator, Function> constants) + private Global translate( + Function> translator, + Function> constants) { return new Global( nodePartitioning.flatMap(partitioning -> partitioning.translate(translator, constants)), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java index 2b3fbd6a74871..39896a5d56105 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.LocalProperty; import com.facebook.presto.spi.SortingProperty; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialMergePushdownStrategy; import com.facebook.presto.sql.parser.SqlParser; @@ -122,6 +123,7 @@ import static com.facebook.presto.sql.planner.plan.ExchangeNode.partitionedExchange; import static com.facebook.presto.sql.planner.plan.ExchangeNode.replicatedExchange; import static com.facebook.presto.sql.planner.plan.ExchangeNode.roundRobinExchange; +import static com.facebook.presto.sql.relational.Expressions.variable; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.google.common.base.Preconditions.checkArgument; @@ -194,7 +196,7 @@ protected PlanWithProperties visitPlan(PlanNode node, PreferredProperties prefer @Override public PlanWithProperties visitProject(ProjectNode node, PreferredProperties preferredProperties) { - Map identities = computeIdentityTranslations(node.getAssignments()); + Map identities = computeIdentityTranslations(node.getAssignments(), types); PreferredProperties translatedPreferred = preferredProperties.translate(symbol -> Optional.ofNullable(identities.get(symbol))); return rebaseAndDeriveProperties(node, planChild(node, translatedPreferred)); @@ -231,7 +233,7 @@ public PlanWithProperties visitEnforceSingleRow(EnforceSingleRowNode node, Prefe @Override public PlanWithProperties visitAggregation(AggregationNode node, PreferredProperties parentPreferredProperties) { - Set partitioningRequirement = ImmutableSet.copyOf(node.getGroupingKeys()); + Set partitioningRequirement = ImmutableSet.copyOf(node.getGroupingKeys()); boolean preferSingleNode = node.hasSingleNodeExecutionPreference(metadata.getFunctionManager()); PreferredProperties preferredProperties = preferSingleNode ? PreferredProperties.undistributed() : PreferredProperties.any(); @@ -260,7 +262,7 @@ else if (!child.getProperties().isStreamPartitionedOn(partitioningRequirement) & selectExchangeScopeForPartitionedRemoteExchange(child.getNode(), false), child.getNode(), createPartitioning(node.getGroupingKeys()), - node.getHashSymbol()), + node.getHashVariable()), child.getProperties()); } return rebaseAndDeriveProperties(node, child); @@ -269,20 +271,20 @@ else if (!child.getProperties().isStreamPartitionedOn(partitioningRequirement) & @Override public PlanWithProperties visitGroupId(GroupIdNode node, PreferredProperties preferredProperties) { - PreferredProperties childPreference = preferredProperties.translate(translateGroupIdSymbols(node)); + PreferredProperties childPreference = preferredProperties.translate(translateGroupIdVariables(node)); PlanWithProperties child = planChild(node, childPreference); return rebaseAndDeriveProperties(node, child); } - private Function> translateGroupIdSymbols(GroupIdNode node) + private Function> translateGroupIdVariables(GroupIdNode node) { - return symbol -> { - if (node.getAggregationArguments().contains(symbol)) { - return Optional.of(symbol); + return variable -> { + if (node.getAggregationArguments().contains(variable)) { + return Optional.of(variable); } - if (node.getCommonGroupingColumns().contains(symbol)) { - return Optional.of(node.getGroupingColumns().get(symbol)); + if (node.getCommonGroupingColumns().contains(variable)) { + return Optional.of((node.getGroupingColumns().get(variable))); } return Optional.empty(); @@ -292,19 +294,19 @@ private Function> translateGroupIdSymbols(GroupIdNode n @Override public PlanWithProperties visitMarkDistinct(MarkDistinctNode node, PreferredProperties preferredProperties) { - PreferredProperties preferredChildProperties = PreferredProperties.partitionedWithLocal(ImmutableSet.copyOf(node.getDistinctSymbols()), grouped(node.getDistinctSymbols())) + PreferredProperties preferredChildProperties = PreferredProperties.partitionedWithLocal(ImmutableSet.copyOf(node.getDistinctVariables()), grouped(node.getDistinctVariables())) .mergeWithParent(preferredProperties); PlanWithProperties child = node.getSource().accept(this, preferredChildProperties); if (child.getProperties().isSingleNode() || - !child.getProperties().isStreamPartitionedOn(node.getDistinctSymbols())) { + !child.getProperties().isStreamPartitionedOn(node.getDistinctVariables())) { child = withDerivedProperties( partitionedExchange( idAllocator.getNextId(), selectExchangeScopeForPartitionedRemoteExchange(child.getNode(), false), child.getNode(), - createPartitioning(node.getDistinctSymbols()), - node.getHashSymbol()), + createPartitioning(node.getDistinctVariables()), + node.getHashVariable()), child.getProperties()); } @@ -314,13 +316,13 @@ public PlanWithProperties visitMarkDistinct(MarkDistinctNode node, PreferredProp @Override public PlanWithProperties visitWindow(WindowNode node, PreferredProperties preferredProperties) { - List> desiredProperties = new ArrayList<>(); + List> desiredProperties = new ArrayList<>(); if (!node.getPartitionBy().isEmpty()) { desiredProperties.add(new GroupingProperty<>(node.getPartitionBy())); } node.getOrderingScheme().ifPresent(orderingScheme -> orderingScheme.getOrderBy().stream() - .map(symbol -> new SortingProperty<>(symbol, orderingScheme.getOrdering(symbol))) + .map(variable -> new SortingProperty<>(variable, orderingScheme.getOrdering(variable))) .forEach(desiredProperties::add)); PlanWithProperties child = planChild( @@ -342,7 +344,7 @@ public PlanWithProperties visitWindow(WindowNode node, PreferredProperties prefe selectExchangeScopeForPartitionedRemoteExchange(child.getNode(), false), child.getNode(), createPartitioning(node.getPartitionBy()), - node.getHashSymbol()), + node.getHashVariable()), child.getProperties()); } } @@ -379,7 +381,7 @@ public PlanWithProperties visitRowNumber(RowNumberNode node, PreferredProperties selectExchangeScopeForPartitionedRemoteExchange(child.getNode(), false), child.getNode(), createPartitioning(node.getPartitionBy()), - node.getHashSymbol()), + node.getHashVariable()), child.getProperties()); } @@ -406,7 +408,7 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, PreferredPr selectExchangeScopeForPartitionedRemoteExchange(partial, false), partial, createPartitioning(node.getPartitionBy()), - node.getHashSymbol()); + node.getHashVariable()); } PlanWithProperties child = planChild(node, preferredChildProperties); @@ -418,10 +420,10 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, PreferredPr idAllocator.getNextId(), child.getNode(), node.getSpecification(), - node.getRowNumberSymbol(), + node.getRowNumberVariable(), node.getMaxRowCountPerPartition(), true, - node.getHashSymbol()), + node.getHashVariable()), child.getProperties()); child = withDerivedProperties(addExchange.apply(child.getNode()), child.getProperties()); @@ -462,9 +464,9 @@ public PlanWithProperties visitSort(SortNode node, PreferredProperties preferred // current plan so far is single node, so local properties are effectively global properties // skip the SortNode if the local properties guarantee ordering on Sort keys // TODO: This should be extracted as a separate optimizer once the planner is able to reason about the ordering of each operator - List> desiredProperties = new ArrayList<>(); - for (Symbol symbol : node.getOrderingScheme().getOrderBy()) { - desiredProperties.add(new SortingProperty<>(symbol, node.getOrderingScheme().getOrdering(symbol))); + List> desiredProperties = new ArrayList<>(); + for (VariableReferenceExpression variable : node.getOrderingScheme().getOrderBy()) { + desiredProperties.add(new SortingProperty<>(variable, node.getOrderingScheme().getOrdering(variable))); } if (LocalProperties.match(child.getProperties().getLocalProperties(), desiredProperties).stream() @@ -526,7 +528,7 @@ public PlanWithProperties visitDistinctLimit(DistinctLimitNode node, PreferredPr gatheringExchange( idAllocator.getNextId(), REMOTE_STREAMING, - new DistinctLimitNode(idAllocator.getNextId(), child.getNode(), node.getLimit(), true, node.getDistinctSymbols(), node.getHashSymbol())), + new DistinctLimitNode(idAllocator.getNextId(), child.getNode(), node.getLimit(), true, node.getDistinctVariables(), node.getHashVariable())), child.getProperties()); } @@ -557,10 +559,10 @@ public PlanWithProperties visitTableWriter(TableWriterNode node, PreferredProper Optional partitioningScheme = node.getPartitioningScheme(); if (!partitioningScheme.isPresent()) { if (scaleWriters) { - partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()), source.getNode().getOutputSymbols())); + partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()), source.getNode().getOutputVariables())); } else if (redistributeWrites) { - partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), source.getNode().getOutputSymbols())); + partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), source.getNode().getOutputVariables())); } } @@ -676,10 +678,10 @@ private Function createDirectTranslator(SetMultimap inputToOutpu @Override public PlanWithProperties visitJoin(JoinNode node, PreferredProperties preferredProperties) { - List leftSymbols = node.getCriteria().stream() + List leftVariables = node.getCriteria().stream() .map(JoinNode.EquiJoinClause::getLeft) .collect(toImmutableList()); - List rightSymbols = node.getCriteria().stream() + List rightVariables = node.getCriteria().stream() .map(JoinNode.EquiJoinClause::getRight) .collect(toImmutableList()); @@ -690,30 +692,30 @@ public PlanWithProperties visitJoin(JoinNode node, PreferredProperties preferred // use partitioned join if probe side is naturally partitioned on join symbols (e.g: because of aggregation) if (!node.getCriteria().isEmpty() - && left.getProperties().isNodePartitionedOn(leftSymbols) && !left.getProperties().isSingleNode()) { - return planPartitionedJoin(node, leftSymbols, rightSymbols, left); + && left.getProperties().isNodePartitionedOn(leftVariables) && !left.getProperties().isSingleNode()) { + return planPartitionedJoin(node, leftVariables, rightVariables, left); } return planReplicatedJoin(node, left); } else { - return planPartitionedJoin(node, leftSymbols, rightSymbols); + return planPartitionedJoin(node, leftVariables, rightVariables); } } - private PlanWithProperties planPartitionedJoin(JoinNode node, List leftSymbols, List rightSymbols) + private PlanWithProperties planPartitionedJoin(JoinNode node, List leftVariables, List rightVariables) { - return planPartitionedJoin(node, leftSymbols, rightSymbols, node.getLeft().accept(this, PreferredProperties.partitioned(ImmutableSet.copyOf(leftSymbols)))); + return planPartitionedJoin(node, leftVariables, rightVariables, node.getLeft().accept(this, PreferredProperties.partitioned(ImmutableSet.copyOf(leftVariables)))); } - private PlanWithProperties planPartitionedJoin(JoinNode node, List leftSymbols, List rightSymbols, PlanWithProperties left) + private PlanWithProperties planPartitionedJoin(JoinNode node, List leftVariables, List rightVariables, PlanWithProperties left) { - SetMultimap rightToLeft = createMapping(rightSymbols, leftSymbols); - SetMultimap leftToRight = createMapping(leftSymbols, rightSymbols); + SetMultimap rightToLeft = createMapping(rightVariables, leftVariables); + SetMultimap leftToRight = createMapping(leftVariables, rightVariables); PlanWithProperties right; - if (left.getProperties().isNodePartitionedOn(leftSymbols) && !left.getProperties().isSingleNode()) { + if (left.getProperties().isNodePartitionedOn(leftVariables) && !left.getProperties().isSingleNode()) { Partitioning rightPartitioning = left.getProperties().translate(createTranslator(leftToRight)).getNodePartitioning().get(); right = node.getRight().accept(this, PreferredProperties.partitioned(rightPartitioning)); if (!right.getProperties().isCompatibleTablePartitioningWith(left.getProperties(), rightToLeft::get, metadata, session) && @@ -725,21 +727,21 @@ private PlanWithProperties planPartitionedJoin(JoinNode node, List leftS idAllocator.getNextId(), selectExchangeScopeForPartitionedRemoteExchange(right.getNode(), false), right.getNode(), - new PartitioningScheme(rightPartitioning, right.getNode().getOutputSymbols())), + new PartitioningScheme(rightPartitioning, right.getNode().getOutputVariables())), right.getProperties()); } } else { - right = node.getRight().accept(this, PreferredProperties.partitioned(ImmutableSet.copyOf(rightSymbols))); + right = node.getRight().accept(this, PreferredProperties.partitioned(ImmutableSet.copyOf(rightVariables))); - if (right.getProperties().isNodePartitionedOn(rightSymbols) && !right.getProperties().isSingleNode()) { + if (right.getProperties().isNodePartitionedOn(rightVariables) && !right.getProperties().isSingleNode()) { Partitioning leftPartitioning = right.getProperties().translate(createTranslator(rightToLeft)).getNodePartitioning().get(); left = withDerivedProperties( partitionedExchange( idAllocator.getNextId(), selectExchangeScopeForPartitionedRemoteExchange(left.getNode(), false), left.getNode(), - new PartitioningScheme(leftPartitioning, left.getNode().getOutputSymbols())), + new PartitioningScheme(leftPartitioning, left.getNode().getOutputVariables())), left.getProperties()); } else { @@ -748,7 +750,7 @@ private PlanWithProperties planPartitionedJoin(JoinNode node, List leftS idAllocator.getNextId(), selectExchangeScopeForPartitionedRemoteExchange(left.getNode(), false), left.getNode(), - createPartitioning(leftSymbols), + createPartitioning(leftVariables), Optional.empty()), left.getProperties()); right = withDerivedProperties( @@ -756,7 +758,7 @@ private PlanWithProperties planPartitionedJoin(JoinNode node, List leftS idAllocator.getNextId(), selectExchangeScopeForPartitionedRemoteExchange(right.getNode(), false), right.getNode(), - createPartitioning(rightSymbols), + createPartitioning(rightVariables), Optional.empty()), right.getProperties()); } @@ -774,7 +776,7 @@ private PlanWithProperties planPartitionedJoin(JoinNode node, List leftS idAllocator.getNextId(), selectExchangeScopeForPartitionedRemoteExchange(right.getNode(), false), right.getNode(), - new PartitioningScheme(rightPartitioning, right.getNode().getOutputSymbols())), + new PartitioningScheme(rightPartitioning, right.getNode().getOutputVariables())), right.getProperties()); } @@ -810,10 +812,10 @@ private PlanWithProperties buildJoin(JoinNode node, PlanWithProperties newLeft, newLeft.getNode(), newRight.getNode(), node.getCriteria(), - node.getOutputSymbols(), + node.getOutputVariables(), node.getFilter(), - node.getLeftHashSymbol(), - node.getRightHashSymbol(), + node.getLeftHashVariable(), + node.getRightHashVariable(), Optional.of(newDistributionType)); return new PlanWithProperties(result, deriveProperties(result, ImmutableList.of(newLeft.getProperties(), newRight.getProperties()))); @@ -843,10 +845,10 @@ public PlanWithProperties visitSpatialJoin(SpatialJoinNode node, PreferredProper } else { left = withDerivedProperties( - partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, left.getNode(), createPartitioning(ImmutableList.of(node.getLeftPartitionSymbol().get())), Optional.empty()), + partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, left.getNode(), createPartitioning(ImmutableList.of(node.getLeftPartitionVariable().get())), Optional.empty()), left.getProperties()); right = withDerivedProperties( - partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, right.getNode(), createPartitioning(ImmutableList.of(node.getRightPartitionSymbol().get())), Optional.empty()), + partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, right.getNode(), createPartitioning(ImmutableList.of(node.getRightPartitionVariable().get())), Optional.empty()), right.getProperties()); } @@ -857,7 +859,7 @@ public PlanWithProperties visitSpatialJoin(SpatialJoinNode node, PreferredProper @Override public PlanWithProperties visitUnnest(UnnestNode node, PreferredProperties preferredProperties) { - PreferredProperties translatedPreferred = preferredProperties.translate(symbol -> node.getReplicateSymbols().contains(symbol) ? Optional.of(symbol) : Optional.empty()); + PreferredProperties translatedPreferred = preferredProperties.translate(variable -> node.getReplicateVariables().contains(new Symbol(variable.getName())) ? Optional.of(variable) : Optional.empty()); return rebaseAndDeriveProperties(node, planChild(node, translatedPreferred)); } @@ -870,15 +872,15 @@ public PlanWithProperties visitSemiJoin(SemiJoinNode node, PreferredProperties p SemiJoinNode.DistributionType distributionType = node.getDistributionType().orElseThrow(() -> new IllegalArgumentException("distributionType not yet set")); if (distributionType == SemiJoinNode.DistributionType.PARTITIONED) { - List sourceSymbols = ImmutableList.of(node.getSourceJoinSymbol()); - List filteringSourceSymbols = ImmutableList.of(node.getFilteringSourceJoinSymbol()); + List sourceVariables = ImmutableList.of(node.getSourceJoinVariable()); + List filteringSourceVariables = ImmutableList.of(node.getFilteringSourceJoinVariable()); - SetMultimap sourceToFiltering = createMapping(sourceSymbols, filteringSourceSymbols); - SetMultimap filteringToSource = createMapping(filteringSourceSymbols, sourceSymbols); + SetMultimap sourceToFiltering = createMapping(sourceVariables, filteringSourceVariables); + SetMultimap filteringToSource = createMapping(filteringSourceVariables, sourceVariables); - source = node.getSource().accept(this, PreferredProperties.partitioned(ImmutableSet.copyOf(sourceSymbols))); + source = node.getSource().accept(this, PreferredProperties.partitioned(ImmutableSet.copyOf(sourceVariables))); - if (source.getProperties().isNodePartitionedOn(sourceSymbols) && !source.getProperties().isSingleNode()) { + if (source.getProperties().isNodePartitionedOn(sourceVariables) && !source.getProperties().isSingleNode()) { Partitioning filteringPartitioning = source.getProperties().translate(createTranslator(sourceToFiltering)).getNodePartitioning().get(); filteringSource = node.getFilteringSource().accept(this, PreferredProperties.partitionedWithNullsAndAnyReplicated(filteringPartitioning)); // TODO: Deprecate compatible table partitioning @@ -888,7 +890,7 @@ public PlanWithProperties visitSemiJoin(SemiJoinNode node, PreferredProperties p filteringSource = withDerivedProperties( partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, filteringSource.getNode(), new PartitioningScheme( filteringPartitioning, - filteringSource.getNode().getOutputSymbols(), + filteringSource.getNode().getOutputVariables(), Optional.empty(), true, Optional.empty())), @@ -896,12 +898,12 @@ public PlanWithProperties visitSemiJoin(SemiJoinNode node, PreferredProperties p } } else { - filteringSource = node.getFilteringSource().accept(this, PreferredProperties.partitionedWithNullsAndAnyReplicated(ImmutableSet.copyOf(filteringSourceSymbols))); + filteringSource = node.getFilteringSource().accept(this, PreferredProperties.partitionedWithNullsAndAnyReplicated(ImmutableSet.copyOf(filteringSourceVariables))); - if (filteringSource.getProperties().isNodePartitionedOn(filteringSourceSymbols, true) && !filteringSource.getProperties().isSingleNode()) { + if (filteringSource.getProperties().isNodePartitionedOn(filteringSourceVariables, true) && !filteringSource.getProperties().isSingleNode()) { Partitioning sourcePartitioning = filteringSource.getProperties().translate(createTranslator(filteringToSource)).getNodePartitioning().get(); source = withDerivedProperties( - partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, source.getNode(), new PartitioningScheme(sourcePartitioning, source.getNode().getOutputSymbols())), + partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, source.getNode(), new PartitioningScheme(sourcePartitioning, source.getNode().getOutputVariables())), source.getProperties()); } else { @@ -910,11 +912,11 @@ public PlanWithProperties visitSemiJoin(SemiJoinNode node, PreferredProperties p idAllocator.getNextId(), REMOTE_STREAMING, source.getNode(), - createPartitioning(sourceSymbols), + createPartitioning(sourceVariables), Optional.empty()), source.getProperties()); filteringSource = withDerivedProperties( - partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, filteringSource.getNode(), createPartitioning(filteringSourceSymbols), Optional.empty(), true), + partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, filteringSource.getNode(), createPartitioning(filteringSourceVariables), Optional.empty(), true), filteringSource.getProperties()); } } @@ -929,7 +931,7 @@ public PlanWithProperties visitSemiJoin(SemiJoinNode node, PreferredProperties p filteringSource = withDerivedProperties( partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, filteringSource.getNode(), new PartitioningScheme( filteringPartitioning, - filteringSource.getNode().getOutputSymbols(), + filteringSource.getNode().getOutputVariables(), Optional.empty(), true, Optional.empty())), @@ -964,12 +966,12 @@ public PlanWithProperties visitSemiJoin(SemiJoinNode node, PreferredProperties p @Override public PlanWithProperties visitIndexJoin(IndexJoinNode node, PreferredProperties preferredProperties) { - List joinColumns = node.getCriteria().stream() + List joinColumns = node.getCriteria().stream() .map(IndexJoinNode.EquiJoinClause::getProbe) .collect(toImmutableList()); // Only prefer grouping on join columns if no parent local property preferences - List> desiredLocalProperties = preferredProperties.getLocalProperties().isEmpty() ? grouped(joinColumns) : ImmutableList.of(); + List> desiredLocalProperties = preferredProperties.getLocalProperties().isEmpty() ? grouped(joinColumns) : ImmutableList.of(); PlanWithProperties probeSource = node.getProbeSource().accept(this, PreferredProperties.partitionedWithLocal(ImmutableSet.copyOf(joinColumns), desiredLocalProperties) .mergeWithParent(preferredProperties)); @@ -980,7 +982,7 @@ public PlanWithProperties visitIndexJoin(IndexJoinNode node, PreferredProperties // TODO: allow repartitioning if unpartitioned to increase parallelism if (shouldRepartitionForIndexJoin(joinColumns, preferredProperties, probeProperties)) { probeSource = withDerivedProperties( - partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, probeSource.getNode(), createPartitioning(joinColumns), node.getProbeHashSymbol()), + partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, probeSource.getNode(), createPartitioning(joinColumns), node.getProbeHashVariable()), probeProperties); } @@ -991,7 +993,7 @@ public PlanWithProperties visitIndexJoin(IndexJoinNode node, PreferredProperties return new PlanWithProperties(result, deriveProperties(result, ImmutableList.of(probeSource.getProperties(), indexSource.getProperties()))); } - private boolean shouldRepartitionForIndexJoin(List joinColumns, PreferredProperties parentPreferredProperties, ActualProperties probeProperties) + private boolean shouldRepartitionForIndexJoin(List joinColumns, PreferredProperties parentPreferredProperties, ActualProperties probeProperties) { // See if distributed index joins are enabled if (!distributedIndexJoins) { @@ -1034,9 +1036,9 @@ public PlanWithProperties visitIndexSource(IndexSourceNode node, PreferredProper .build()); } - private Function> outputToInputTranslator(UnionNode node, int sourceIndex) + private Function> outputToInputTranslator(UnionNode node, int sourceIndex, TypeProvider types) { - return symbol -> Optional.of(node.getSymbolMapping().get(symbol).get(sourceIndex)); + return variable -> Optional.of(node.getVariableMapping().get(variable).get(sourceIndex)); } private Partitioning selectUnionPartitioning(UnionNode node, PartitioningProperties parentPreference) @@ -1049,7 +1051,7 @@ private Partitioning selectUnionPartitioning(UnionNode node, PartitioningPropert // Try planning the children to see if any of them naturally produce a partitioning (for now, just select the first) boolean nullsAndAnyReplicated = parentPreference.isNullsAndAnyReplicated(); for (int sourceIndex = 0; sourceIndex < node.getSources().size(); sourceIndex++) { - PartitioningProperties childPartitioning = parentPreference.translate(outputToInputTranslator(node, sourceIndex)).get(); + PreferredProperties.PartitioningProperties childPartitioning = parentPreference.translate(outputToInputTranslator(node, sourceIndex, types)).get(); PreferredProperties childPreferred = PreferredProperties.builder() .global(PreferredProperties.Global.distributed(childPartitioning.withNullsAndAnyReplicated(nullsAndAnyReplicated))) .build(); @@ -1058,7 +1060,9 @@ private Partitioning selectUnionPartitioning(UnionNode node, PartitioningPropert // Theoretically, if all children are single partitioned on the same node we could choose a single // partitioning, but as this only applies to a union of two values nodes, it isn't worth the added complexity if (child.getProperties().isNodePartitionedOn(childPartitioning.getPartitioningColumns(), nullsAndAnyReplicated) && !child.getProperties().isSingleNode()) { - Function> childToParent = createTranslator(createMapping(node.sourceOutputLayout(sourceIndex), node.getOutputSymbols())); + Function> childToParent = createTranslator(createMapping( + node.sourceOutputLayout(sourceIndex), + node.getOutputVariables())); return child.getProperties().translate(childToParent).getNodePartitioning().get(); } } @@ -1079,10 +1083,12 @@ public PlanWithProperties visitUnion(UnionNode node, PreferredProperties parentP Partitioning desiredParentPartitioning = selectUnionPartitioning(node, parentPartitioningProperties); ImmutableList.Builder partitionedSources = ImmutableList.builder(); - ImmutableListMultimap.Builder outputToSourcesMapping = ImmutableListMultimap.builder(); + ImmutableListMultimap.Builder outputToSourcesMapping = ImmutableListMultimap.builder(); for (int sourceIndex = 0; sourceIndex < node.getSources().size(); sourceIndex++) { - Partitioning childPartitioning = desiredParentPartitioning.translate(createDirectTranslator(createMapping(node.getOutputSymbols(), node.sourceOutputLayout(sourceIndex)))); + Partitioning childPartitioning = desiredParentPartitioning.translate(createDirectTranslator(createMapping( + node.getOutputVariables(), + node.sourceOutputLayout(sourceIndex)))); PreferredProperties childPreferred = PreferredProperties.builder() .global(PreferredProperties.Global.distributed(PartitioningProperties.partitioned(childPartitioning) @@ -1100,7 +1106,7 @@ public PlanWithProperties visitUnion(UnionNode node, PreferredProperties parentP source.getNode(), new PartitioningScheme( childPartitioning, - source.getNode().getOutputSymbols(), + source.getNode().getOutputVariables(), Optional.empty(), nullsAndAnyReplicated, Optional.empty())), @@ -1108,15 +1114,14 @@ public PlanWithProperties visitUnion(UnionNode node, PreferredProperties parentP } partitionedSources.add(source.getNode()); - for (int column = 0; column < node.getOutputSymbols().size(); column++) { - outputToSourcesMapping.put(node.getOutputSymbols().get(column), node.sourceOutputLayout(sourceIndex).get(column)); + for (int column = 0; column < node.getOutputVariables().size(); column++) { + outputToSourcesMapping.put(node.getOutputVariables().get(column), node.sourceOutputLayout(sourceIndex).get(column)); } } UnionNode newNode = new UnionNode( node.getId(), partitionedSources.build(), - outputToSourcesMapping.build(), - ImmutableList.copyOf(outputToSourcesMapping.build().keySet())); + outputToSourcesMapping.build()); return new PlanWithProperties( newNode, @@ -1133,10 +1138,10 @@ public PlanWithProperties visitUnion(UnionNode node, PreferredProperties parentP // first, classify children into single node and distributed List singleNodeChildren = new ArrayList<>(); - List> singleNodeOutputLayouts = new ArrayList<>(); + List> singleNodeOutputLayouts = new ArrayList<>(); List distributedChildren = new ArrayList<>(); - List> distributedOutputLayouts = new ArrayList<>(); + List> distributedOutputLayouts = new ArrayList<>(); for (int i = 0; i < node.getSources().size(); i++) { PlanWithProperties child = node.getSources().get(i).accept(this, PreferredProperties.any()); @@ -1165,7 +1170,7 @@ public PlanWithProperties visitUnion(UnionNode node, PreferredProperties parentP idAllocator.getNextId(), GATHER, REMOTE_STREAMING, - new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), node.getOutputSymbols()), + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), node.getOutputVariables()), distributedChildren, distributedOutputLayouts, Optional.empty()); @@ -1175,8 +1180,8 @@ else if (!singleNodeChildren.isEmpty()) { // add a gathering exchange above partitioned inputs and fold it into the set of unpartitioned inputs // NOTE: new symbols for ExchangeNode output are required in order to keep plan logically correct with new local union below - List exchangeOutputLayout = node.getOutputSymbols().stream() - .map(outputSymbol -> symbolAllocator.newSymbol(outputSymbol.getName(), types.get(outputSymbol))) + List exchangeOutputLayout = node.getOutputVariables().stream() + .map(symbolAllocator::newVariable) .collect(toImmutableList()); result = new ExchangeNode( @@ -1189,18 +1194,19 @@ else if (!singleNodeChildren.isEmpty()) { Optional.empty()); singleNodeChildren.add(result); - singleNodeOutputLayouts.add(result.getOutputSymbols()); + // TODO use result.getOutputVariable() after symbol to variable refactoring is done. This is a temporary hack since we know the value should be exchangeOutputLayout. + singleNodeOutputLayouts.add(exchangeOutputLayout); } - ImmutableListMultimap.Builder mappings = ImmutableListMultimap.builder(); - for (int i = 0; i < node.getOutputSymbols().size(); i++) { - for (List outputLayout : singleNodeOutputLayouts) { - mappings.put(node.getOutputSymbols().get(i), outputLayout.get(i)); + ImmutableListMultimap.Builder mappings = ImmutableListMultimap.builder(); + for (int i = 0; i < node.getOutputVariables().size(); i++) { + for (List outputLayout : singleNodeOutputLayouts) { + mappings.put(node.getOutputVariables().get(i), outputLayout.get(i)); } } // add local union for all unpartitioned inputs - result = new UnionNode(node.getId(), singleNodeChildren, mappings.build(), ImmutableList.copyOf(mappings.build().keySet())); + result = new UnionNode(node.getId(), singleNodeChildren, mappings.build()); } else { throw new IllegalStateException("both singleNodeChildren distributedChildren are empty"); @@ -1216,7 +1222,7 @@ else if (!singleNodeChildren.isEmpty()) { private PlanWithProperties arbitraryDistributeUnion( UnionNode node, List distributedChildren, - List> distributedOutputLayouts) + List> distributedOutputLayouts) { // TODO: can we insert LOCAL exchange for one child SOURCE distributed and another HASH distributed? if (getNumberOfTableScans(distributedChildren) == 0) { @@ -1233,7 +1239,7 @@ private PlanWithProperties arbitraryDistributeUnion( idAllocator.getNextId(), REPARTITION, REMOTE_STREAMING, - new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), node.getOutputSymbols()), + new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), node.getOutputVariables()), distributedChildren, distributedOutputLayouts, Optional.empty())); @@ -1297,7 +1303,7 @@ private ActualProperties derivePropertiesRecursively(PlanNode result) return PropertyDerivations.derivePropertiesRecursively(result, metadata, session, types, parser); } - private Partitioning createPartitioning(List partitioningColumns) + private Partitioning createPartitioning(List partitioningColumns) { // TODO: Use SystemTablesMetadata instead of introducing a special case if (GlobalSystemConnector.NAME.equals(partitioningProviderCatalog)) { @@ -1305,7 +1311,7 @@ private Partitioning createPartitioning(List partitioningColumns) } List partitioningTypes = partitioningColumns.stream() - .map(column -> symbolAllocator.getTypes().get(column)) + .map(VariableReferenceExpression::getType) .collect(toImmutableList()); PartitioningHandle partitioningHandle = metadata.getPartitioningHandleForExchange(session, partitioningProviderCatalog, hashPartitionCount, partitioningTypes); return Partitioning.create(partitioningHandle, partitioningColumns); @@ -1315,7 +1321,7 @@ private Partitioning createPartitioning(List partitioningColumns) // materialized exchange is supported for all nodes. private Scope selectExchangeScopeForPartitionedRemoteExchange(PlanNode exchangeSource, boolean nullsAndAnyReplicated) { - if (nullsAndAnyReplicated || exchangeSource.getOutputSymbols().isEmpty()) { + if (nullsAndAnyReplicated || exchangeSource.getOutputVariables().isEmpty()) { // materialized remote exchange is not supported when // * replicateNullsAndAny is needed // * materializing 0 columns input is not supported @@ -1366,27 +1372,32 @@ private boolean canPushdownPartialMergeThroughLowMemoryOperators(PlanNode node) .allMatch(this::canPushdownPartialMergeThroughLowMemoryOperators); } - private static Map computeIdentityTranslations(Assignments assignments) + public static Map computeIdentityTranslations(Assignments assignments, TypeProvider types) { - Map outputToInput = new HashMap<>(); - for (Map.Entry assignment : assignments.getMap().entrySet()) { + Map outputToInput = new HashMap<>(); + for (Map.Entry assignment : assignments.getMap().entrySet()) { if (assignment.getValue() instanceof SymbolReference) { - outputToInput.put(assignment.getKey(), Symbol.from(assignment.getValue())); + outputToInput.put(assignment.getKey(), toVariableReference(Symbol.from(assignment.getValue()), types)); } } return outputToInput; } + public static VariableReferenceExpression toVariableReference(Symbol symbol, TypeProvider types) + { + return variable(symbol.getName(), types.get(symbol)); + } + @VisibleForTesting static Comparator streamingExecutionPreference(PreferredProperties preferred) { // Calculating the matches can be a bit expensive, so cache the results between comparisons - LoadingCache>, List>>> matchCache = CacheBuilder.newBuilder() + LoadingCache>, List>>> matchCache = CacheBuilder.newBuilder() .build(CacheLoader.from(actualProperties -> LocalProperties.match(actualProperties, preferred.getLocalProperties()))); return (actual1, actual2) -> { - List>> matchLayout1 = matchCache.getUnchecked(actual1.getLocalProperties()); - List>> matchLayout2 = matchCache.getUnchecked(actual2.getLocalProperties()); + List>> matchLayout1 = matchCache.getUnchecked(actual1.getLocalProperties()); + List>> matchLayout2 = matchCache.getUnchecked(actual2.getLocalProperties()); return ComparisonChain.start() .compareTrueFirst(hasLocalOptimization(preferred.getLocalProperties(), matchLayout1), hasLocalOptimization(preferred.getLocalProperties(), matchLayout2)) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java index a83a18624c536..4608d4c983b4c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java @@ -21,10 +21,10 @@ import com.facebook.presto.spi.LocalProperty; import com.facebook.presto.spi.SortingProperty; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; @@ -55,7 +55,6 @@ import com.facebook.presto.sql.planner.plan.WindowNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Iterator; @@ -283,7 +282,7 @@ public PlanWithProperties visitAggregation(AggregationNode node, StreamPreferred return planAndEnforceChildren(node, singleStream(), defaultParallelism(session)); } - List groupingKeys = node.getGroupingKeys(); + List groupingKeys = node.getGroupingKeys(); if (node.hasDefaultOutput()) { checkState(node.isDecomposable(metadata.getFunctionManager())); @@ -302,13 +301,13 @@ public PlanWithProperties visitAggregation(AggregationNode node, StreamPreferred } StreamPreferredProperties childRequirements = parentPreferences - .constrainTo(node.getSource().getOutputSymbols()) + .constrainTo(node.getSource().getOutputVariables()) .withDefaultParallelism(session) .withPartitioning(groupingKeys); PlanWithProperties child = planAndEnforce(node.getSource(), childRequirements, childRequirements); - List preGroupedSymbols = ImmutableList.of(); + List preGroupedSymbols = ImmutableList.of(); if (!LocalProperties.match(child.getProperties().getLocalProperties(), LocalProperties.grouped(groupingKeys)).get(0).isPresent()) { // !isPresent() indicates the property was satisfied completely preGroupedSymbols = groupingKeys; @@ -321,8 +320,8 @@ public PlanWithProperties visitAggregation(AggregationNode node, StreamPreferred node.getGroupingSets(), preGroupedSymbols, node.getStep(), - node.getHashSymbol(), - node.getGroupIdSymbol()); + node.getHashVariable(), + node.getGroupIdVariable()); return deriveProperties(result, child.getProperties()); } @@ -331,28 +330,28 @@ public PlanWithProperties visitAggregation(AggregationNode node, StreamPreferred public PlanWithProperties visitWindow(WindowNode node, StreamPreferredProperties parentPreferences) { StreamPreferredProperties childRequirements = parentPreferences - .constrainTo(node.getSource().getOutputSymbols()) + .constrainTo(node.getSource().getOutputVariables()) .withDefaultParallelism(session) .withPartitioning(node.getPartitionBy()); PlanWithProperties child = planAndEnforce(node.getSource(), childRequirements, childRequirements); - List> desiredProperties = new ArrayList<>(); + List> desiredProperties = new ArrayList<>(); if (!node.getPartitionBy().isEmpty()) { desiredProperties.add(new GroupingProperty<>(node.getPartitionBy())); } node.getOrderingScheme().ifPresent(orderingScheme -> orderingScheme.getOrderBy().stream() - .map(symbol -> new SortingProperty<>(symbol, orderingScheme.getOrdering(symbol))) + .map(variable -> new SortingProperty<>(variable, orderingScheme.getOrdering(variable))) .forEach(desiredProperties::add)); - Iterator>> matchIterator = LocalProperties.match(child.getProperties().getLocalProperties(), desiredProperties).iterator(); + Iterator>> matchIterator = LocalProperties.match(child.getProperties().getLocalProperties(), desiredProperties).iterator(); - Set prePartitionedInputs = ImmutableSet.of(); + Set prePartitionedInputs = ImmutableSet.of(); if (!node.getPartitionBy().isEmpty()) { - Optional> groupingRequirement = matchIterator.next(); - Set unPartitionedInputs = groupingRequirement.map(LocalProperty::getColumns).orElse(ImmutableSet.of()); + Optional> groupingRequirement = matchIterator.next(); + Set unPartitionedInputs = groupingRequirement.map(LocalProperty::getColumns).orElse(ImmutableSet.of()); prePartitionedInputs = node.getPartitionBy().stream() - .filter(symbol -> !unPartitionedInputs.contains(symbol)) + .filter(variable -> !unPartitionedInputs.contains(variable)) .collect(toImmutableSet()); } @@ -368,7 +367,7 @@ public PlanWithProperties visitWindow(WindowNode node, StreamPreferredProperties child.getNode(), node.getSpecification(), node.getWindowFunctions(), - node.getHashSymbol(), + node.getHashVariable(), prePartitionedInputs, preSortedOrderPrefix); @@ -380,18 +379,18 @@ public PlanWithProperties visitMarkDistinct(MarkDistinctNode node, StreamPreferr { // mark distinct requires that all data partitioned StreamPreferredProperties childRequirements = parentPreferences - .constrainTo(node.getSource().getOutputSymbols()) + .constrainTo(node.getSource().getOutputVariables()) .withDefaultParallelism(session) - .withPartitioning(node.getDistinctSymbols()); + .withPartitioning(node.getDistinctVariables()); PlanWithProperties child = planAndEnforce(node.getSource(), childRequirements, childRequirements); MarkDistinctNode result = new MarkDistinctNode( node.getId(), child.getNode(), - node.getMarkerSymbol(), - pruneMarkDistinctSymbols(node, child.getProperties().getLocalProperties()), - node.getHashSymbol()); + node.getMarkerVariable(), + pruneMarkDistinctVariables(node, child.getProperties().getLocalProperties()), + node.getHashVariable()); return deriveProperties(result, child.getProperties()); } @@ -417,33 +416,33 @@ public PlanWithProperties visitMarkDistinct(MarkDistinctNode node, StreamPreferr * Ideally, this logic would be encapsulated in a separate rule, but currently no rule other * than AddLocalExchanges can reason about local properties. */ - private List pruneMarkDistinctSymbols(MarkDistinctNode node, List> localProperties) + private List pruneMarkDistinctVariables(MarkDistinctNode node, List> localProperties) { if (localProperties.isEmpty()) { - return node.getDistinctSymbols(); + return node.getDistinctVariables(); } // Identify functional dependencies between distinct symbols: in the list of local properties any constant // symbol is functionally dependent on the set of symbols that appears earlier. - ImmutableSet.Builder redundantSymbolsBuilder = ImmutableSet.builder(); - for (LocalProperty property : localProperties) { + ImmutableSet.Builder redundantVariablesBuilder = ImmutableSet.builder(); + for (LocalProperty property : localProperties) { if (property instanceof ConstantProperty) { - redundantSymbolsBuilder.add(((ConstantProperty) property).getColumn()); + redundantVariablesBuilder.add(((ConstantProperty) property).getColumn()); } - else if (!node.getDistinctSymbols().containsAll(property.getColumns())) { + else if (!node.getDistinctVariables().containsAll(property.getColumns())) { // Ran into a non-distinct symbol. There will be no more symbols that are functionally dependent on distinct symbols exclusively. break; } } - Set redundantSymbols = redundantSymbolsBuilder.build(); - List remainingSymbols = node.getDistinctSymbols().stream() - .filter(symbol -> !redundantSymbols.contains(symbol)) + Set redundantVariables = redundantVariablesBuilder.build(); + List remainingSymbols = node.getDistinctVariables().stream() + .filter(variable -> !redundantVariables.contains(variable)) .collect(toImmutableList()); if (remainingSymbols.isEmpty()) { // This happens when all distinct symbols are constants. // In that case, keep the first symbol (don't drop them all). - return ImmutableList.of(node.getDistinctSymbols().get(0)); + return ImmutableList.of(node.getDistinctVariables().get(0)); } return remainingSymbols; } @@ -523,7 +522,7 @@ public PlanWithProperties visitUnion(UnionNode node, StreamPreferredProperties p .map(PlanWithProperties::getProperties) .collect(toImmutableList()); - List> inputLayouts = new ArrayList<>(sources.size()); + List> inputLayouts = new ArrayList<>(sources.size()); for (int i = 0; i < sources.size(); i++) { inputLayouts.add(node.sourceOutputLayout(i)); } @@ -533,14 +532,14 @@ public PlanWithProperties visitUnion(UnionNode node, StreamPreferredProperties p idAllocator.getNextId(), GATHER, LOCAL, - new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), node.getOutputSymbols()), + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), node.getOutputVariables()), sources, inputLayouts, Optional.empty()); return deriveProperties(exchangeNode, inputProperties); } - Optional> preferredPartitionColumns = preferredProperties.getPartitioningColumns(); + Optional> preferredPartitionColumns = preferredProperties.getPartitioningColumns(); if (preferredPartitionColumns.isPresent()) { ExchangeNode exchangeNode = new ExchangeNode( idAllocator.getNextId(), @@ -548,8 +547,7 @@ public PlanWithProperties visitUnion(UnionNode node, StreamPreferredProperties p LOCAL, new PartitioningScheme( Partitioning.create(FIXED_HASH_DISTRIBUTION, preferredPartitionColumns.get()), - node.getOutputSymbols(), - Optional.empty()), + node.getOutputVariables()), sources, inputLayouts, Optional.empty()); @@ -561,7 +559,7 @@ public PlanWithProperties visitUnion(UnionNode node, StreamPreferredProperties p idAllocator.getNextId(), REPARTITION, LOCAL, - new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), node.getOutputSymbols()), + new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), node.getOutputVariables()), sources, inputLayouts, Optional.empty()); @@ -582,20 +580,22 @@ public PlanWithProperties visitJoin(JoinNode node, StreamPreferredProperties par probe = planAndEnforce( node.getLeft(), fixedParallelism(), - parentPreferences.constrainTo(node.getLeft().getOutputSymbols()).withFixedParallelism()); + parentPreferences.constrainTo(node.getLeft().getOutputVariables()).withFixedParallelism()); } else { probe = planAndEnforce( node.getLeft(), defaultParallelism(session), - parentPreferences.constrainTo(node.getLeft().getOutputSymbols()).withDefaultParallelism(session)); + parentPreferences.constrainTo(node.getLeft().getOutputVariables()).withDefaultParallelism(session)); } // this build consumes the input completely, so we do not pass through parent preferences - List buildHashSymbols = Lists.transform(node.getCriteria(), JoinNode.EquiJoinClause::getRight); + List buildHashVariables = node.getCriteria().stream() + .map(JoinNode.EquiJoinClause::getRight) + .collect(toImmutableList()); StreamPreferredProperties buildPreference; if (getTaskConcurrency(session) > 1) { - buildPreference = exactlyPartitionedOn(buildHashSymbols); + buildPreference = exactlyPartitionedOn(buildHashVariables); } else { buildPreference = singleStream(); @@ -611,7 +611,7 @@ public PlanWithProperties visitSemiJoin(SemiJoinNode node, StreamPreferredProper PlanWithProperties source = planAndEnforce( node.getSource(), defaultParallelism(session), - parentPreferences.constrainTo(node.getSource().getOutputSymbols()).withDefaultParallelism(session)); + parentPreferences.constrainTo(node.getSource().getOutputVariables()).withDefaultParallelism(session)); // this filter source consumes the input completely, so we do not pass through parent preferences PlanWithProperties filteringSource = planAndEnforce(node.getFilteringSource(), singleStream(), singleStream()); @@ -625,7 +625,7 @@ public PlanWithProperties visitSpatialJoin(SpatialJoinNode node, StreamPreferred PlanWithProperties probe = planAndEnforce( node.getLeft(), defaultParallelism(session), - parentPreferences.constrainTo(node.getLeft().getOutputSymbols()) + parentPreferences.constrainTo(node.getLeft().getOutputVariables()) .withDefaultParallelism(session)); PlanWithProperties build = planAndEnforce(node.getRight(), singleStream(), singleStream()); @@ -639,7 +639,7 @@ public PlanWithProperties visitIndexJoin(IndexJoinNode node, StreamPreferredProp PlanWithProperties probe = planAndEnforce( node.getProbeSource(), defaultParallelism(session), - parentPreferences.constrainTo(node.getProbeSource().getOutputSymbols()).withDefaultParallelism(session)); + parentPreferences.constrainTo(node.getProbeSource().getOutputVariables()).withDefaultParallelism(session)); // index source does not support local parallel and must produce a single stream StreamProperties indexStreamProperties = derivePropertiesRecursively(node.getIndexSource(), metadata, session, types, parser); @@ -660,8 +660,8 @@ private PlanWithProperties planAndEnforceChildren(PlanNode node, StreamPreferred List children = node.getSources().stream() .map(source -> planAndEnforce( source, - requiredProperties.constrainTo(source.getOutputSymbols()), - preferredProperties.constrainTo(source.getOutputSymbols()))) + requiredProperties.constrainTo(source.getOutputVariables()), + preferredProperties.constrainTo(source.getOutputVariables()))) .collect(toImmutableList()); return rebaseAndDeriveProperties(node, children); @@ -670,9 +670,8 @@ private PlanWithProperties planAndEnforceChildren(PlanNode node, StreamPreferred private PlanWithProperties planAndEnforce(PlanNode node, StreamPreferredProperties requiredProperties, StreamPreferredProperties preferredProperties) { // verify properties are in terms of symbols produced by the node - List outputSymbols = node.getOutputSymbols(); - checkArgument(requiredProperties.getPartitioningColumns().map(outputSymbols::containsAll).orElse(true)); - checkArgument(preferredProperties.getPartitioningColumns().map(outputSymbols::containsAll).orElse(true)); + checkArgument(requiredProperties.getPartitioningColumns().map(node.getOutputVariables()::containsAll).orElse(true)); + checkArgument(preferredProperties.getPartitioningColumns().map(node.getOutputVariables()::containsAll).orElse(true)); // plan the node using the preferred properties PlanWithProperties result = node.accept(this, preferredProperties); @@ -695,7 +694,7 @@ private PlanWithProperties enforce(PlanWithProperties planWithProperties, Stream return deriveProperties(exchangeNode, planWithProperties.getProperties()); } - Optional> requiredPartitionColumns = requiredProperties.getPartitioningColumns(); + Optional> requiredPartitionColumns = requiredProperties.getPartitioningColumns(); if (!requiredPartitionColumns.isPresent()) { // unpartitioned parallel streams required return deriveProperties( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AggregationNodeUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AggregationNodeUtils.java index 162d7ea1a2ef0..f390d2b9a24d5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AggregationNodeUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AggregationNodeUtils.java @@ -14,8 +14,10 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.metadata.FunctionManager; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.relational.FunctionResolution; import com.google.common.collect.ImmutableList; @@ -24,6 +26,8 @@ import java.util.Optional; import java.util.Set; +import static com.google.common.collect.ImmutableSet.toImmutableSet; + public class AggregationNodeUtils { private AggregationNodeUtils() {} @@ -44,6 +48,18 @@ public static Set extractUnique(AggregationNode.Aggregation aggregation) ImmutableSet.Builder builder = ImmutableSet.builder(); aggregation.getArguments().forEach(argument -> builder.addAll(SymbolsExtractor.extractAll(argument))); aggregation.getFilter().ifPresent(filter -> builder.addAll(SymbolsExtractor.extractAll(filter))); + aggregation.getOrderBy().ifPresent(orderingScheme -> builder.addAll(orderingScheme.getOrderBy().stream() + .map(VariableReferenceExpression::getName) + .map(Symbol::new) + .collect(toImmutableSet()))); + return builder.build(); + } + + public static Set extractUniqueVariables(AggregationNode.Aggregation aggregation, TypeProvider types) + { + ImmutableSet.Builder builder = ImmutableSet.builder(); + aggregation.getArguments().forEach(argument -> builder.addAll(SymbolsExtractor.extractAllVariable(argument, types))); + aggregation.getFilter().ifPresent(filter -> builder.addAll(SymbolsExtractor.extractAllVariable(filter, types))); aggregation.getOrderBy().ifPresent(orderingScheme -> builder.addAll(orderingScheme.getOrderBy())); return builder.build(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/BeginTableWrite.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/BeginTableWrite.java index 5ba9ddd764ca7..17689fe2aef16 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/BeginTableWrite.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/BeginTableWrite.java @@ -94,9 +94,9 @@ public PlanNode visitTableWriter(TableWriterNode node, RewriteContext c node.getId(), node.getSource().accept(this, context), writerTarget, - node.getRowCountSymbol(), - node.getFragmentSymbol(), - node.getTableCommitContextSymbol(), + node.getRowCountVariable(), + node.getFragmentVariable(), + node.getTableCommitContextVariable(), node.getColumns(), node.getColumnNames(), node.getPartitioningScheme(), @@ -113,7 +113,7 @@ public PlanNode visitDelete(DeleteNode node, RewriteContext context) rewriteDeleteTableScan(node.getSource(), deleteHandle.getHandle()), deleteHandle, node.getRowId(), - node.getOutputSymbols()); + node.getOutputVariables()); } @Override @@ -129,7 +129,7 @@ public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteCont node.getId(), child, analyzeHandle, - node.getRowCountSymbol(), + node.getRowCountVariable(), node.isRowCountEnabled(), node.getDescriptor()); } @@ -149,7 +149,7 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext c node.getId(), child, newTarget, - node.getRowCountSymbol(), + node.getRowCountVariable(), node.getStatisticsAggregation(), node.getStatisticsAggregationDescriptor()); } @@ -205,7 +205,7 @@ private PlanNode rewriteDeleteTableScan(PlanNode node, TableHandle handle) return new TableScanNode( scan.getId(), layoutResult.getLayout().getNewTableHandle(), - scan.getOutputSymbols(), + scan.getOutputVariables(), scan.getAssignments(), layoutResult.getLayout().getPredicate(), computeEnforced(originalEnforcedConstraint, layoutResult.getUnenforcedConstraint())); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CheckSubqueryNodesAreRewritten.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CheckSubqueryNodesAreRewritten.java index 31f01bd3d24cd..c63742e3096f1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CheckSubqueryNodesAreRewritten.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CheckSubqueryNodesAreRewritten.java @@ -17,8 +17,8 @@ import com.facebook.presto.Session; import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.SemanticException; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.ApplyNode; @@ -54,7 +54,7 @@ public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, Sym return plan; } - private SemanticException error(List correlation, String originSubqueryError) + private SemanticException error(List correlation, String originSubqueryError) { checkState(!correlation.isEmpty(), "All the non correlated subqueries should be rewritten at this point"); throw new RuntimeException(format(originSubqueryError, "Given correlated subquery is not supported")); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java index fcaea75cbdce9..a7f5ebaff1a97 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java @@ -83,14 +83,14 @@ public ExpressionEquivalence(Metadata metadata, SqlParser sqlParser) public boolean areExpressionsEquivalent(Session session, Expression leftExpression, Expression rightExpression, TypeProvider types) { - Map symbolInput = new HashMap<>(); + Map variableInput = new HashMap<>(); int inputId = 0; for (Entry entry : types.allTypes().entrySet()) { - symbolInput.put(entry.getKey(), inputId); + variableInput.put(new VariableReferenceExpression(entry.getKey().getName(), entry.getValue()), inputId); inputId++; } - RowExpression leftRowExpression = toRowExpression(session, leftExpression, symbolInput, types); - RowExpression rightRowExpression = toRowExpression(session, rightExpression, symbolInput, types); + RowExpression leftRowExpression = toRowExpression(session, leftExpression, variableInput, types); + RowExpression rightRowExpression = toRowExpression(session, rightExpression, variableInput, types); RowExpression canonicalizedLeft = leftRowExpression.accept(canonicalizationVisitor, null); RowExpression canonicalizedRight = rightRowExpression.accept(canonicalizationVisitor, null); @@ -98,7 +98,7 @@ public boolean areExpressionsEquivalent(Session session, Expression leftExpressi return canonicalizedLeft.equals(canonicalizedRight); } - private RowExpression toRowExpression(Session session, Expression expression, Map symbolInput, TypeProvider types) + private RowExpression toRowExpression(Session session, Expression expression, Map variableInput, TypeProvider types) { // replace qualified names with input references since row expressions do not support these @@ -113,7 +113,7 @@ private RowExpression toRowExpression(Session session, Expression expression, Ma WarningCollector.NOOP); // convert to row expression - return translate(expression, expressionTypes, symbolInput, metadata.getFunctionManager(), metadata.getTypeManager(), session, false); + return translate(expression, expressionTypes, variableInput, metadata.getFunctionManager(), metadata.getTypeManager(), session, false); } private static class CanonicalizationVisitor diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java index 492f556ec338f..583a02a675359 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java @@ -18,10 +18,10 @@ import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; import com.facebook.presto.sql.planner.PartitioningScheme; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -32,7 +32,6 @@ import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; import com.facebook.presto.sql.planner.plan.IndexJoinNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode.EquiJoinClause; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; @@ -110,7 +109,7 @@ public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, Sym requireNonNull(symbolAllocator, "symbolAllocator is null"); requireNonNull(idAllocator, "idAllocator is null"); if (SystemSessionProperties.isOptimizeHashGenerationEnabled(session)) { - PlanWithProperties result = plan.accept(new Rewriter(idAllocator, symbolAllocator, types), new HashComputationSet()); + PlanWithProperties result = plan.accept(new Rewriter(idAllocator, symbolAllocator), new HashComputationSet()); return result.getNode(); } return plan; @@ -121,13 +120,11 @@ private static class Rewriter { private final PlanNodeIdAllocator idAllocator; private final SymbolAllocator symbolAllocator; - private final TypeProvider types; - private Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, TypeProvider types) + private Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) { this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); - this.types = requireNonNull(types, "types is null"); } @Override @@ -139,7 +136,7 @@ protected PlanWithProperties visitPlan(PlanNode node, HashComputationSet parentP @Override public PlanWithProperties visitEnforceSingleRow(EnforceSingleRowNode node, HashComputationSet parentPreference) { - // this plan node can only have a single input symbol, so do not add extra hash symbols + // this plan node can only have a single input variable, so do not add extra hash variables return planSimpleNodeWithProperties(node, new HashComputationSet(), true); } @@ -163,15 +160,16 @@ public PlanWithProperties visitLateralJoin(LateralJoinNode node, HashComputation public PlanWithProperties visitAggregation(AggregationNode node, HashComputationSet parentPreference) { Optional groupByHash = Optional.empty(); + List groupingKeys = node.getGroupingKeys(); if (!node.isStreamable() && !canSkipHashGeneration(node.getGroupingKeys())) { - groupByHash = computeHash(node.getGroupingKeys()); + groupByHash = computeHash(groupingKeys); } - // aggregation does not pass through preferred hash symbols + // aggregation does not pass through preferred hash variables HashComputationSet requiredHashes = new HashComputationSet(groupByHash); PlanWithProperties child = planAndEnforce(node.getSource(), requiredHashes, false, requiredHashes); - Optional hashSymbol = groupByHash.map(child::getRequiredHashSymbol); + Optional hashVariable = groupByHash.map(child::getRequiredHashVariable); return new PlanWithProperties( new AggregationNode( @@ -179,69 +177,69 @@ public PlanWithProperties visitAggregation(AggregationNode node, HashComputation child.getNode(), node.getAggregations(), node.getGroupingSets(), - node.getPreGroupedSymbols(), + node.getPreGroupedVariables(), node.getStep(), - hashSymbol, - node.getGroupIdSymbol()), - hashSymbol.isPresent() ? ImmutableMap.of(groupByHash.get(), hashSymbol.get()) : ImmutableMap.of()); + hashVariable, + node.getGroupIdVariable()), + hashVariable.isPresent() ? ImmutableMap.of(groupByHash.get(), hashVariable.get()) : ImmutableMap.of()); } - private boolean canSkipHashGeneration(List partitionSymbols) + private boolean canSkipHashGeneration(List partitionVariables) { // HACK: bigint grouped aggregation has special operators that do not use precomputed hash, so we can skip hash generation - return partitionSymbols.isEmpty() || (partitionSymbols.size() == 1 && types.get(Iterables.getOnlyElement(partitionSymbols)).equals(BIGINT)); + return partitionVariables.isEmpty() || (partitionVariables.size() == 1 && Iterables.getOnlyElement(partitionVariables).getType().equals(BIGINT)); } @Override public PlanWithProperties visitGroupId(GroupIdNode node, HashComputationSet parentPreference) { - // remove any hash symbols not exported by the source of this node - return planSimpleNodeWithProperties(node, parentPreference.pruneSymbols(node.getSource().getOutputSymbols())); + // remove any hash variables not exported by the source of this node + return planSimpleNodeWithProperties(node, parentPreference.pruneVariables(node.getSource().getOutputVariables())); } @Override public PlanWithProperties visitDistinctLimit(DistinctLimitNode node, HashComputationSet parentPreference) { - // skip hash symbol generation for single bigint - if (canSkipHashGeneration(node.getDistinctSymbols())) { + // skip hash variable generation for single bigint + if (canSkipHashGeneration(node.getDistinctVariables())) { return planSimpleNodeWithProperties(node, parentPreference); } - Optional hashComputation = computeHash(node.getDistinctSymbols()); + Optional hashComputation = computeHash(node.getDistinctVariables()); PlanWithProperties child = planAndEnforce( node.getSource(), new HashComputationSet(hashComputation), false, parentPreference.withHashComputation(node, hashComputation)); - Symbol hashSymbol = child.getRequiredHashSymbol(hashComputation.get()); + VariableReferenceExpression hashVariable = child.getRequiredHashVariable(hashComputation.get()); - // TODO: we need to reason about how pre-computed hashes from child relate to distinct symbols. We should be able to include any precomputed hash + // TODO: we need to reason about how pre-computed hashes from child relate to distinct variables. We should be able to include any precomputed hash // that's functionally dependent on the distinct field in the set of distinct fields of the new node to be able to propagate it downstream. // Currently, such precomputed hashes will be dropped by this operation. return new PlanWithProperties( - new DistinctLimitNode(node.getId(), child.getNode(), node.getLimit(), node.isPartial(), node.getDistinctSymbols(), Optional.of(hashSymbol)), - ImmutableMap.of(hashComputation.get(), hashSymbol)); + new DistinctLimitNode(node.getId(), child.getNode(), node.getLimit(), node.isPartial(), node.getDistinctVariables(), Optional.of(hashVariable)), + ImmutableMap.of(hashComputation.get(), hashVariable)); } @Override public PlanWithProperties visitMarkDistinct(MarkDistinctNode node, HashComputationSet parentPreference) { - // skip hash symbol generation for single bigint - if (canSkipHashGeneration(node.getDistinctSymbols())) { + // skip hash variable generation for single bigint + if (canSkipHashGeneration(node.getDistinctVariables())) { return planSimpleNodeWithProperties(node, parentPreference); } - Optional hashComputation = computeHash(node.getDistinctSymbols()); + Optional hashComputation = computeHash(node.getDistinctVariables()); PlanWithProperties child = planAndEnforce( node.getSource(), new HashComputationSet(hashComputation), false, parentPreference.withHashComputation(node, hashComputation)); - Symbol hashSymbol = child.getRequiredHashSymbol(hashComputation.get()); + VariableReferenceExpression hashVariable = child.getRequiredHashVariable(hashComputation.get()); return new PlanWithProperties( - new MarkDistinctNode(node.getId(), child.getNode(), node.getMarkerSymbol(), node.getDistinctSymbols(), Optional.of(hashSymbol)), - child.getHashSymbols()); + new MarkDistinctNode(node.getId(), child.getNode(), node.getMarkerVariable(), node.getDistinctVariables(), Optional.of(hashVariable)), + child.getHashVariables()); } @Override @@ -257,17 +255,17 @@ public PlanWithProperties visitRowNumber(RowNumberNode node, HashComputationSet new HashComputationSet(hashComputation), false, parentPreference.withHashComputation(node, hashComputation)); - Symbol hashSymbol = child.getRequiredHashSymbol(hashComputation.get()); + VariableReferenceExpression hashVariable = child.getRequiredHashVariable(hashComputation.get()); return new PlanWithProperties( new RowNumberNode( node.getId(), child.getNode(), node.getPartitionBy(), - node.getRowNumberSymbol(), + node.getRowNumberVariable(), node.getMaxRowCountPerPartition(), - Optional.of(hashSymbol)), - child.getHashSymbols()); + Optional.of(hashVariable)), + child.getHashVariables()); } @Override @@ -283,18 +281,18 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, HashComputa new HashComputationSet(hashComputation), false, parentPreference.withHashComputation(node, hashComputation)); - Symbol hashSymbol = child.getRequiredHashSymbol(hashComputation.get()); + VariableReferenceExpression hashVariable = child.getRequiredHashVariable(hashComputation.get()); return new PlanWithProperties( new TopNRowNumberNode( node.getId(), child.getNode(), node.getSpecification(), - node.getRowNumberSymbol(), + node.getRowNumberVariable(), node.getMaxRowCountPerPartition(), node.isPartial(), - Optional.of(hashSymbol)), - child.getHashSymbols()); + Optional.of(hashVariable)), + child.getHashVariables()); } @Override @@ -302,58 +300,59 @@ public PlanWithProperties visitJoin(JoinNode node, HashComputationSet parentPref { List clauses = node.getCriteria(); if (clauses.isEmpty()) { - // join does not pass through preferred hash symbols since they take more memory and since + // join does not pass through preferred hash variables since they take more memory and since // the join node filters, may take more compute PlanWithProperties left = planAndEnforce(node.getLeft(), new HashComputationSet(), true, new HashComputationSet()); PlanWithProperties right = planAndEnforce(node.getRight(), new HashComputationSet(), true, new HashComputationSet()); - checkState(left.getHashSymbols().isEmpty() && right.getHashSymbols().isEmpty()); + checkState(left.getHashVariables().isEmpty() && right.getHashVariables().isEmpty()); return new PlanWithProperties( replaceChildren(node, ImmutableList.of(left.getNode(), right.getNode())), ImmutableMap.of()); } - // join does not pass through preferred hash symbols since they take more memory and since + // join does not pass through preferred hash variables since they take more memory and since // the join node filters, may take more compute Optional leftHashComputation = computeHash(Lists.transform(clauses, JoinNode.EquiJoinClause::getLeft)); PlanWithProperties left = planAndEnforce(node.getLeft(), new HashComputationSet(leftHashComputation), true, new HashComputationSet(leftHashComputation)); - Symbol leftHashSymbol = left.getRequiredHashSymbol(leftHashComputation.get()); + VariableReferenceExpression leftHashVariable = left.getRequiredHashVariable(leftHashComputation.get()); Optional rightHashComputation = computeHash(Lists.transform(clauses, JoinNode.EquiJoinClause::getRight)); - // drop undesired hash symbols from build to save memory + // drop undesired hash variables from build to save memory PlanWithProperties right = planAndEnforce(node.getRight(), new HashComputationSet(rightHashComputation), true, new HashComputationSet(rightHashComputation)); - Symbol rightHashSymbol = right.getRequiredHashSymbol(rightHashComputation.get()); + VariableReferenceExpression rightHashVariable = right.getRequiredHashVariable(rightHashComputation.get()); - // build map of all hash symbols - // NOTE: Full outer join doesn't use hash symbols - Map allHashSymbols = new HashMap<>(); + // build map of all hash variables + // NOTE: Full outer join doesn't use hash variables + Map allHashVariables = new HashMap<>(); if (node.getType() == INNER || node.getType() == LEFT) { - allHashSymbols.putAll(left.getHashSymbols()); + allHashVariables.putAll(left.getHashVariables()); } if (node.getType() == INNER || node.getType() == RIGHT) { - allHashSymbols.putAll(right.getHashSymbols()); + allHashVariables.putAll(right.getHashVariables()); } - return buildJoinNodeWithPreferredHashes(node, left, right, allHashSymbols, parentPreference, Optional.of(leftHashSymbol), Optional.of(rightHashSymbol)); + return buildJoinNodeWithPreferredHashes(node, left, right, allHashVariables, parentPreference, Optional.of(leftHashVariable), Optional.of(rightHashVariable)); } private PlanWithProperties buildJoinNodeWithPreferredHashes( JoinNode node, PlanWithProperties left, PlanWithProperties right, - Map allHashSymbols, + Map allHashVariables, HashComputationSet parentPreference, - Optional leftHashSymbol, - Optional rightHashSymbol) + Optional leftHashVariable, + Optional rightHashVariable) { - // retain only hash symbols preferred by parent nodes - Map hashSymbolsWithParentPreferences = - allHashSymbols.entrySet() + // retain only hash variables preferred by parent nodes + Map hashVariablesWithParentPreferences = + allHashVariables.entrySet() .stream() .filter(entry -> parentPreference.getHashes().contains(entry.getKey())) .collect(toImmutableMap(Entry::getKey, Entry::getValue)); - List outputSymbols = concat(left.getNode().getOutputSymbols().stream(), right.getNode().getOutputSymbols().stream()) - .filter(symbol -> node.getOutputSymbols().contains(symbol) || hashSymbolsWithParentPreferences.values().contains(symbol)) + List outputVariables = concat(left.getNode().getOutputVariables().stream(), right.getNode().getOutputVariables().stream()) + .filter(variable -> node.getOutputVariables().contains(variable) || + hashVariablesWithParentPreferences.values().contains(variable)) .collect(toImmutableList()); return new PlanWithProperties( @@ -363,42 +362,42 @@ private PlanWithProperties buildJoinNodeWithPreferredHashes( left.getNode(), right.getNode(), node.getCriteria(), - outputSymbols, + outputVariables, node.getFilter(), - leftHashSymbol, - rightHashSymbol, + leftHashVariable, + rightHashVariable, node.getDistributionType()), - hashSymbolsWithParentPreferences); + hashVariablesWithParentPreferences); } @Override public PlanWithProperties visitSemiJoin(SemiJoinNode node, HashComputationSet parentPreference) { - Optional sourceHashComputation = computeHash(ImmutableList.of(node.getSourceJoinSymbol())); + Optional sourceHashComputation = computeHash(ImmutableList.of(node.getSourceJoinVariable())); PlanWithProperties source = planAndEnforce( node.getSource(), new HashComputationSet(sourceHashComputation), true, new HashComputationSet(sourceHashComputation)); - Symbol sourceHashSymbol = source.getRequiredHashSymbol(sourceHashComputation.get()); + VariableReferenceExpression sourceHashVariable = source.getRequiredHashVariable(sourceHashComputation.get()); - Optional filterHashComputation = computeHash(ImmutableList.of(node.getFilteringSourceJoinSymbol())); + Optional filterHashComputation = computeHash(ImmutableList.of(node.getFilteringSourceJoinVariable())); HashComputationSet requiredHashes = new HashComputationSet(filterHashComputation); PlanWithProperties filteringSource = planAndEnforce(node.getFilteringSource(), requiredHashes, true, requiredHashes); - Symbol filteringSourceHashSymbol = filteringSource.getRequiredHashSymbol(filterHashComputation.get()); + VariableReferenceExpression filteringSourceHashVariable = filteringSource.getRequiredHashVariable(filterHashComputation.get()); return new PlanWithProperties( new SemiJoinNode( node.getId(), source.getNode(), filteringSource.getNode(), - node.getSourceJoinSymbol(), - node.getFilteringSourceJoinSymbol(), + node.getSourceJoinVariable(), + node.getFilteringSourceJoinVariable(), node.getSemiJoinOutput(), - Optional.of(sourceHashSymbol), - Optional.of(filteringSourceHashSymbol), + Optional.of(sourceHashVariable), + Optional.of(filteringSourceHashVariable), node.getDistributionType()), - source.getHashSymbols()); + source.getHashVariables()); } @Override @@ -406,8 +405,8 @@ public PlanWithProperties visitSpatialJoin(SpatialJoinNode node, HashComputation { PlanWithProperties left = planAndEnforce(node.getLeft(), new HashComputationSet(), true, new HashComputationSet()); PlanWithProperties right = planAndEnforce(node.getRight(), new HashComputationSet(), true, new HashComputationSet()); - verify(left.getHashSymbols().isEmpty(), "probe side of the spatial join should not include hash symbols"); - verify(right.getHashSymbols().isEmpty(), "build side of the spatial join should not include hash symbols"); + verify(left.getHashVariables().isEmpty(), "probe side of the spatial join should not include hash variables"); + verify(right.getHashVariables().isEmpty(), "build side of the spatial join should not include hash variables"); return new PlanWithProperties( replaceChildren(node, ImmutableList.of(left.getNode(), right.getNode())), ImmutableMap.of()); @@ -418,7 +417,7 @@ public PlanWithProperties visitIndexJoin(IndexJoinNode node, HashComputationSet { List clauses = node.getCriteria(); - // join does not pass through preferred hash symbols since they take more memory and since + // join does not pass through preferred hash variables since they take more memory and since // the join node filters, may take more compute Optional probeHashComputation = computeHash(Lists.transform(clauses, IndexJoinNode.EquiJoinClause::getProbe)); PlanWithProperties probe = planAndEnforce( @@ -426,19 +425,19 @@ public PlanWithProperties visitIndexJoin(IndexJoinNode node, HashComputationSet new HashComputationSet(probeHashComputation), true, new HashComputationSet(probeHashComputation)); - Symbol probeHashSymbol = probe.getRequiredHashSymbol(probeHashComputation.get()); + VariableReferenceExpression probeHashVariable = probe.getRequiredHashVariable(probeHashComputation.get()); - Optional indexHashComputation = computeHash(Lists.transform(clauses, EquiJoinClause::getIndex)); + Optional indexHashComputation = computeHash(Lists.transform(clauses, IndexJoinNode.EquiJoinClause::getIndex)); HashComputationSet requiredHashes = new HashComputationSet(indexHashComputation); PlanWithProperties index = planAndEnforce(node.getIndexSource(), requiredHashes, true, requiredHashes); - Symbol indexHashSymbol = index.getRequiredHashSymbol(indexHashComputation.get()); + VariableReferenceExpression indexHashVariable = index.getRequiredHashVariable(indexHashComputation.get()); - // build map of all hash symbols - Map allHashSymbols = new HashMap<>(); + // build map of all hash variables + Map allHashVariables = new HashMap<>(); if (node.getType() == IndexJoinNode.Type.INNER) { - allHashSymbols.putAll(probe.getHashSymbols()); + allHashVariables.putAll(probe.getHashVariables()); } - allHashSymbols.putAll(index.getHashSymbols()); + allHashVariables.putAll(index.getHashVariables()); return new PlanWithProperties( new IndexJoinNode( @@ -447,9 +446,9 @@ public PlanWithProperties visitIndexJoin(IndexJoinNode node, HashComputationSet probe.getNode(), index.getNode(), node.getCriteria(), - Optional.of(probeHashSymbol), - Optional.of(indexHashSymbol)), - allHashSymbols); + Optional.of(probeHashVariable), + Optional.of(indexHashVariable)), + allHashVariables); } @Override @@ -466,7 +465,7 @@ public PlanWithProperties visitWindow(WindowNode node, HashComputationSet parent true, parentPreference.withHashComputation(node, hashComputation)); - Symbol hashSymbol = child.getRequiredHashSymbol(hashComputation.get()); + VariableReferenceExpression hashSymbol = child.getRequiredHashVariable(hashComputation.get()); return new PlanWithProperties( new WindowNode( @@ -477,73 +476,73 @@ public PlanWithProperties visitWindow(WindowNode node, HashComputationSet parent Optional.of(hashSymbol), node.getPrePartitionedInputs(), node.getPreSortedOrderPrefix()), - child.getHashSymbols()); + child.getHashVariables()); } @Override public PlanWithProperties visitExchange(ExchangeNode node, HashComputationSet parentPreference) { - // remove any hash symbols not exported by this node - HashComputationSet preference = parentPreference.pruneSymbols(node.getOutputSymbols()); + // remove any hash variables not exported by this node + HashComputationSet preference = parentPreference.pruneVariables(node.getOutputVariables()); // Currently, precomputed hash values are only supported for system hash distributions without constants - Optional partitionSymbols = Optional.empty(); + Optional partitionVariables = Optional.empty(); PartitioningScheme partitioningScheme = node.getPartitioningScheme(); if (partitioningScheme.getPartitioning().getHandle().equals(FIXED_HASH_DISTRIBUTION) && partitioningScheme.getPartitioning().getArguments().stream().allMatch(ArgumentBinding::isVariable)) { // add precomputed hash for exchange - partitionSymbols = computeHash(partitioningScheme.getPartitioning().getArguments().stream() - .map(ArgumentBinding::getColumn) + partitionVariables = computeHash(partitioningScheme.getPartitioning().getArguments().stream() + .map(ArgumentBinding::getVariableReference) .collect(toImmutableList())); - preference = preference.withHashComputation(partitionSymbols); + preference = preference.withHashComputation(partitionVariables); } - // establish fixed ordering for hash symbols - List hashSymbolOrder = ImmutableList.copyOf(preference.getHashes()); - Map newHashSymbols = new HashMap<>(); - for (HashComputation preferredHashSymbol : hashSymbolOrder) { - newHashSymbols.put(preferredHashSymbol, symbolAllocator.newHashSymbol()); + // establish fixed ordering for hash variables + List hashVariableOrder = ImmutableList.copyOf(preference.getHashes()); + Map newHashVariables = new HashMap<>(); + for (HashComputation preferredHashVariable : hashVariableOrder) { + newHashVariables.put(preferredHashVariable, symbolAllocator.newHashVariable()); } - // rewrite partition function to include new symbols (and precomputed hash + // rewrite partition function to include new variables (and precomputed hash partitioningScheme = new PartitioningScheme( partitioningScheme.getPartitioning(), - ImmutableList.builder() + ImmutableList.builder() .addAll(partitioningScheme.getOutputLayout()) - .addAll(hashSymbolOrder.stream() - .map(newHashSymbols::get) + .addAll(hashVariableOrder.stream() + .map(newHashVariables::get) .collect(toImmutableList())) .build(), - partitionSymbols.map(newHashSymbols::get), + partitionVariables.map(newHashVariables::get), partitioningScheme.isReplicateNullsAndAny(), partitioningScheme.getBucketToPartition()); - // add hash symbols to sources - ImmutableList.Builder> newInputs = ImmutableList.builder(); + // add hash variables to sources + ImmutableList.Builder> newInputs = ImmutableList.builder(); ImmutableList.Builder newSources = ImmutableList.builder(); for (int sourceId = 0; sourceId < node.getSources().size(); sourceId++) { PlanNode source = node.getSources().get(sourceId); - List inputSymbols = node.getInputs().get(sourceId); + List inputVariables = node.getInputs().get(sourceId); - Map outputToInputMap = new HashMap<>(); - for (int symbolId = 0; symbolId < inputSymbols.size(); symbolId++) { - outputToInputMap.put(node.getOutputSymbols().get(symbolId), inputSymbols.get(symbolId)); + Map outputToInputMap = new HashMap<>(); + for (int variableId = 0; variableId < inputVariables.size(); variableId++) { + outputToInputMap.put(node.getOutputVariables().get(variableId), inputVariables.get(variableId)); } - Function> outputToInputTranslator = symbol -> Optional.of(outputToInputMap.get(symbol)); + Function> outputToInputTranslator = variable -> Optional.of(outputToInputMap.get(variable)); HashComputationSet sourceContext = preference.translate(outputToInputTranslator); PlanWithProperties child = planAndEnforce(source, sourceContext, true, sourceContext); newSources.add(child.getNode()); - // add hash symbols to inputs in the required order - ImmutableList.Builder newInputSymbols = ImmutableList.builder(); - newInputSymbols.addAll(node.getInputs().get(sourceId)); - for (HashComputation preferredHashSymbol : hashSymbolOrder) { + // add hash variables to inputs in the required order + ImmutableList.Builder newInputVariables = ImmutableList.builder(); + newInputVariables.addAll(inputVariables); + for (HashComputation preferredHashSymbol : hashVariableOrder) { HashComputation hashComputation = preferredHashSymbol.translate(outputToInputTranslator).get(); - newInputSymbols.add(child.getRequiredHashSymbol(hashComputation)); + newInputVariables.add(child.getRequiredHashVariable(hashComputation)); } - newInputs.add(newInputSymbols.build()); + newInputs.add(newInputVariables.build()); } return new PlanWithProperties( @@ -555,41 +554,41 @@ public PlanWithProperties visitExchange(ExchangeNode node, HashComputationSet pa newSources.build(), newInputs.build(), node.getOrderingScheme()), - newHashSymbols); + newHashVariables); } @Override public PlanWithProperties visitUnion(UnionNode node, HashComputationSet parentPreference) { - // remove any hash symbols not exported by this node - HashComputationSet preference = parentPreference.pruneSymbols(node.getOutputSymbols()); + // remove any hash variables not exported by this node + HashComputationSet preference = parentPreference.pruneVariables(node.getOutputVariables()); - // create new hash symbols - Map newHashSymbols = new HashMap<>(); + // create new hash variables + Map newHashVariables = new HashMap<>(); for (HashComputation preferredHashSymbol : preference.getHashes()) { - newHashSymbols.put(preferredHashSymbol, symbolAllocator.newHashSymbol()); + newHashVariables.put(preferredHashSymbol, symbolAllocator.newHashVariable()); } - // add hash symbols to sources - ImmutableListMultimap.Builder newSymbolMapping = ImmutableListMultimap.builder(); - newSymbolMapping.putAll(node.getSymbolMapping()); + // add hash variables to sources + ImmutableListMultimap.Builder newVariableMapping = ImmutableListMultimap.builder(); + newVariableMapping.putAll(node.getVariableMapping()); ImmutableList.Builder newSources = ImmutableList.builder(); for (int sourceId = 0; sourceId < node.getSources().size(); sourceId++) { - // translate preference to input symbols - Map outputToInputMap = new HashMap<>(); - for (Symbol outputSymbol : node.getOutputSymbols()) { - outputToInputMap.put(outputSymbol, node.getSymbolMapping().get(outputSymbol).get(sourceId)); + // translate preference to input variables + Map outputToInputMap = new HashMap<>(); + for (VariableReferenceExpression outputVariables : node.getOutputVariables()) { + outputToInputMap.put(outputVariables, node.getVariableMapping().get(outputVariables).get(sourceId)); } - Function> outputToInputTranslator = symbol -> Optional.of(outputToInputMap.get(symbol)); + Function> outputToInputTranslator = variable -> Optional.of(outputToInputMap.get(variable)); HashComputationSet sourcePreference = preference.translate(outputToInputTranslator); PlanWithProperties child = planAndEnforce(node.getSources().get(sourceId), sourcePreference, true, sourcePreference); newSources.add(child.getNode()); - // add hash symbols to inputs - for (Entry entry : newHashSymbols.entrySet()) { + // add hash variables to inputs + for (Entry entry : newHashVariables.entrySet()) { HashComputation hashComputation = entry.getKey().translate(outputToInputTranslator).get(); - newSymbolMapping.put(entry.getValue(), child.getRequiredHashSymbol(hashComputation)); + newVariableMapping.put(entry.getValue(), child.getRequiredHashVariable(hashComputation)); } } @@ -597,61 +596,60 @@ public PlanWithProperties visitUnion(UnionNode node, HashComputationSet parentPr new UnionNode( node.getId(), newSources.build(), - newSymbolMapping.build(), - ImmutableList.copyOf(newSymbolMapping.build().keySet())), - newHashSymbols); + newVariableMapping.build()), + newHashVariables); } @Override public PlanWithProperties visitProject(ProjectNode node, HashComputationSet parentPreference) { - Map outputToInputMapping = computeIdentityTranslations(node.getAssignments().getMap()); - HashComputationSet sourceContext = parentPreference.translate(symbol -> Optional.ofNullable(outputToInputMapping.get(symbol))); + Map outputToInputMapping = computeIdentityTranslations(node.getAssignments().getMap()); + HashComputationSet sourceContext = parentPreference.translate(variable -> Optional.ofNullable(outputToInputMapping.get(variable))); PlanWithProperties child = plan(node.getSource(), sourceContext); // create a new project node with all assignments from the original node Assignments.Builder newAssignments = Assignments.builder(); newAssignments.putAll(node.getAssignments()); - // and all hash symbols that could be translated to the source symbols - Map allHashSymbols = new HashMap<>(); + // and all hash variables that could be translated to the source variables + Map allHashVariables = new HashMap<>(); for (HashComputation hashComputation : sourceContext.getHashes()) { - Symbol hashSymbol = child.getHashSymbols().get(hashComputation); + VariableReferenceExpression hashVariable = child.getHashVariables().get(hashComputation); Expression hashExpression; - if (hashSymbol == null) { - hashSymbol = symbolAllocator.newHashSymbol(); + if (hashVariable == null) { + hashVariable = symbolAllocator.newHashVariable(); hashExpression = hashComputation.getHashExpression(); } else { - hashExpression = hashSymbol.toSymbolReference(); + hashExpression = new SymbolReference(hashVariable.getName()); } - newAssignments.put(hashSymbol, hashExpression); - allHashSymbols.put(hashComputation, hashSymbol); + newAssignments.put(hashVariable, hashExpression); + allHashVariables.put(hashComputation, hashVariable); } - return new PlanWithProperties(new ProjectNode(node.getId(), child.getNode(), newAssignments.build()), allHashSymbols); + return new PlanWithProperties(new ProjectNode(node.getId(), child.getNode(), newAssignments.build()), allHashVariables); } @Override public PlanWithProperties visitUnnest(UnnestNode node, HashComputationSet parentPreference) { - PlanWithProperties child = plan(node.getSource(), parentPreference.pruneSymbols(node.getSource().getOutputSymbols())); + PlanWithProperties child = plan(node.getSource(), parentPreference.pruneVariables(node.getSource().getOutputVariables())); - // only pass through hash symbols requested by the parent - Map hashSymbols = new HashMap<>(child.getHashSymbols()); - hashSymbols.keySet().retainAll(parentPreference.getHashes()); + // only pass through hash variables requested by the parent + Map hashVariables = new HashMap<>(child.getHashVariables()); + hashVariables.keySet().retainAll(parentPreference.getHashes()); return new PlanWithProperties( new UnnestNode( node.getId(), child.getNode(), - ImmutableList.builder() - .addAll(node.getReplicateSymbols()) - .addAll(hashSymbols.values()) + ImmutableList.builder() + .addAll(node.getReplicateVariables()) + .addAll(hashVariables.values()) .build(), - node.getUnnestSymbols(), - node.getOrdinalitySymbol()), - hashSymbols); + node.getUnnestVariables(), + node.getOrdinalityVariable()), + hashVariables); } private PlanWithProperties planSimpleNodeWithProperties(PlanNode node, HashComputationSet preferredHashes) @@ -662,37 +660,37 @@ private PlanWithProperties planSimpleNodeWithProperties(PlanNode node, HashCompu private PlanWithProperties planSimpleNodeWithProperties( PlanNode node, HashComputationSet preferredHashes, - boolean alwaysPruneExtraHashSymbols) + boolean alwaysPruneExtraHashVariables) { if (node.getSources().isEmpty()) { return new PlanWithProperties(node, ImmutableMap.of()); } - // There is not requirement to produce hash symbols and only preference for symbols - PlanWithProperties source = planAndEnforce(Iterables.getOnlyElement(node.getSources()), new HashComputationSet(), alwaysPruneExtraHashSymbols, preferredHashes); + // There is not requirement to produce hash variables and only preference for variables + PlanWithProperties source = planAndEnforce(Iterables.getOnlyElement(node.getSources()), new HashComputationSet(), alwaysPruneExtraHashVariables, preferredHashes); PlanNode result = replaceChildren(node, ImmutableList.of(source.getNode())); - // return only hash symbols that are passed through the new node - Map hashSymbols = new HashMap<>(source.getHashSymbols()); - hashSymbols.values().retainAll(result.getOutputSymbols()); + // return only hash variables that are passed through the new node + Map hashVariables = new HashMap<>(source.getHashVariables()); + hashVariables.values().retainAll(result.getOutputVariables()); - return new PlanWithProperties(result, hashSymbols); + return new PlanWithProperties(result, hashVariables); } private PlanWithProperties planAndEnforce( PlanNode node, HashComputationSet requiredHashes, - boolean pruneExtraHashSymbols, + boolean pruneExtraHashVariables, HashComputationSet preferredHashes) { PlanWithProperties result = plan(node, preferredHashes); boolean preferenceSatisfied; - if (pruneExtraHashSymbols) { + if (pruneExtraHashVariables) { // Make sure that // (1) result has all required hashes // (2) any extra hashes are preferred hashes (e.g. no pruning is needed) - Set resultHashes = result.getHashSymbols().keySet(); + Set resultHashes = result.getHashVariables().keySet(); Set requiredAndPreferredHashes = ImmutableSet.builder() .addAll(requiredHashes.getHashes()) .addAll(preferredHashes.getHashes()) @@ -701,7 +699,7 @@ private PlanWithProperties planAndEnforce( requiredAndPreferredHashes.containsAll(resultHashes); } else { - preferenceSatisfied = result.getHashSymbols().keySet().containsAll(requiredHashes.getHashes()); + preferenceSatisfied = result.getHashVariables().keySet().containsAll(requiredHashes.getHashes()); } if (preferenceSatisfied) { @@ -715,41 +713,41 @@ private PlanWithProperties enforce(PlanWithProperties planWithProperties, HashCo { Assignments.Builder assignments = Assignments.builder(); - Map outputHashSymbols = new HashMap<>(); + Map outputHashVariables = new HashMap<>(); - // copy through all symbols from child, except for hash symbols not needed by the parent - Map resultHashSymbols = planWithProperties.getHashSymbols().inverse(); - for (Symbol symbol : planWithProperties.getNode().getOutputSymbols()) { - HashComputation partitionSymbols = resultHashSymbols.get(symbol); - if (partitionSymbols == null || requiredHashes.getHashes().contains(partitionSymbols)) { - assignments.put(symbol, symbol.toSymbolReference()); + // copy through all variables from child, except for hash variables not needed by the parent + Map resultHashVariables = planWithProperties.getHashVariables().inverse(); + for (VariableReferenceExpression variable : planWithProperties.getNode().getOutputVariables()) { + HashComputation partitionVariables = resultHashVariables.get(variable); + if (partitionVariables == null || requiredHashes.getHashes().contains(partitionVariables)) { + assignments.put(variable, new SymbolReference(variable.getName())); - if (partitionSymbols != null) { - outputHashSymbols.put(partitionSymbols, symbol); + if (partitionVariables != null) { + outputHashVariables.put(partitionVariables, planWithProperties.getHashVariables().get(partitionVariables)); } } } - // add new projections for hash symbols needed by the parent + // add new projections for hash variables needed by the parent for (HashComputation hashComputation : requiredHashes.getHashes()) { - if (!planWithProperties.getHashSymbols().containsKey(hashComputation)) { + if (!planWithProperties.getHashVariables().containsKey(hashComputation)) { Expression hashExpression = hashComputation.getHashExpression(); - Symbol hashSymbol = symbolAllocator.newHashSymbol(); - assignments.put(hashSymbol, hashExpression); - outputHashSymbols.put(hashComputation, hashSymbol); + VariableReferenceExpression hashVariable = symbolAllocator.newHashVariable(); + assignments.put(hashVariable, hashExpression); + outputHashVariables.put(hashComputation, hashVariable); } } ProjectNode projectNode = new ProjectNode(idAllocator.getNextId(), planWithProperties.getNode(), assignments.build()); - return new PlanWithProperties(projectNode, outputHashSymbols); + return new PlanWithProperties(projectNode, outputHashVariables); } private PlanWithProperties plan(PlanNode node, HashComputationSet parentPreference) { PlanWithProperties result = node.accept(this, parentPreference); checkState( - result.getNode().getOutputSymbols().containsAll(result.getHashSymbols().values()), - "Node %s declares hash symbols not in the output", + result.getNode().getOutputVariables().containsAll(result.getHashVariables().values()), + "Node %s declares hash variables not in the output", result.getNode().getClass().getSimpleName()); return result; } @@ -786,15 +784,15 @@ public Set getHashes() return hashes; } - public HashComputationSet pruneSymbols(List symbols) + public HashComputationSet pruneVariables(List variables) { - Set uniqueSymbols = ImmutableSet.copyOf(symbols); + Set uniqueVariables = ImmutableSet.copyOf(variables); return new HashComputationSet(hashes.stream() - .filter(hash -> hash.canComputeWith(uniqueSymbols)) + .filter(hash -> hash.canComputeWith(uniqueVariables)) .collect(toImmutableSet())); } - public HashComputationSet translate(Function> translator) + public HashComputationSet translate(Function> translator) { Set newHashes = hashes.stream() .map(hash -> hash.translate(translator)) @@ -807,7 +805,7 @@ public HashComputationSet translate(Function> translato public HashComputationSet withHashComputation(PlanNode node, Optional hashComputation) { - return pruneSymbols(node.getOutputSymbols()).withHashComputation(hashComputation); + return pruneVariables(node.getOutputVariables()).withHashComputation(hashComputation); } public HashComputationSet withHashComputation(Optional hashComputation) @@ -822,29 +820,29 @@ public HashComputationSet withHashComputation(Optional hashComp } } - public static Optional computeHash(Iterable fields) + public static Optional computeHash(Iterable fields) { requireNonNull(fields, "fields is null"); - List symbols = ImmutableList.copyOf(fields); - if (symbols.isEmpty()) { + List variables = ImmutableList.copyOf(fields); + if (variables.isEmpty()) { return Optional.empty(); } return Optional.of(new HashComputation(fields)); } - public static Optional getHashExpression(List symbols) + public static Optional getHashExpression(List variables) { - if (symbols.isEmpty()) { + if (variables.isEmpty()) { return Optional.empty(); } Expression result = new GenericLiteral(StandardTypes.BIGINT, String.valueOf(INITIAL_HASH_VALUE)); - for (Symbol symbol : symbols) { + for (VariableReferenceExpression variable : variables) { Expression hashField = new FunctionCall( QualifiedName.of(HASH_CODE), Optional.empty(), false, - ImmutableList.of(new SymbolReference(symbol.getName()))); + ImmutableList.of(new SymbolReference(variable.getName()))); hashField = new CoalesceExpression(hashField, new LongLiteral(String.valueOf(NULL_HASH_CODE))); @@ -855,34 +853,29 @@ public static Optional getHashExpression(List symbols) private static class HashComputation { - private final List fields; + private final List fields; - private HashComputation(Iterable fields) + private HashComputation(Iterable fields) { requireNonNull(fields, "fields is null"); this.fields = ImmutableList.copyOf(fields); checkArgument(!this.fields.isEmpty(), "fields can not be empty"); } - public List getFields() + public Optional translate(Function> translator) { - return fields; - } - - public Optional translate(Function> translator) - { - ImmutableList.Builder newSymbols = ImmutableList.builder(); - for (Symbol field : fields) { - Optional newSymbol = translator.apply(field); - if (!newSymbol.isPresent()) { + ImmutableList.Builder newVariables = ImmutableList.builder(); + for (VariableReferenceExpression field : fields) { + Optional newVariable = translator.apply(field); + if (!newVariable.isPresent()) { return Optional.empty(); } - newSymbols.add(newSymbol.get()); + newVariables.add(newVariable.get()); } - return computeHash(newSymbols.build()); + return computeHash(newVariables.build()); } - public boolean canComputeWith(Set availableFields) + public boolean canComputeWith(Set availableFields) { return availableFields.containsAll(fields); } @@ -890,19 +883,19 @@ public boolean canComputeWith(Set availableFields) private Expression getHashExpression() { Expression hashExpression = new GenericLiteral(StandardTypes.BIGINT, Integer.toString(INITIAL_HASH_VALUE)); - for (Symbol field : fields) { + for (VariableReferenceExpression field : fields) { hashExpression = getHashFunctionCall(hashExpression, field); } return hashExpression; } - private static Expression getHashFunctionCall(Expression previousHashValue, Symbol symbol) + private static Expression getHashFunctionCall(Expression previousHashValue, VariableReferenceExpression variable) { FunctionCall functionCall = new FunctionCall( QualifiedName.of(HASH_CODE), Optional.empty(), false, - ImmutableList.of(symbol.toSymbolReference())); + ImmutableList.of(new SymbolReference(variable.getName()))); List arguments = ImmutableList.of(previousHashValue, orNullHashCode(functionCall)); return new FunctionCall(QualifiedName.of("combine_hash"), arguments); } @@ -943,12 +936,12 @@ public String toString() private static class PlanWithProperties { private final PlanNode node; - private final BiMap hashSymbols; + private final BiMap hashVariables; - public PlanWithProperties(PlanNode node, Map hashSymbols) + public PlanWithProperties(PlanNode node, Map hashVariables) { this.node = requireNonNull(node, "node is null"); - this.hashSymbols = ImmutableBiMap.copyOf(requireNonNull(hashSymbols, "hashSymbols is null")); + this.hashVariables = ImmutableBiMap.copyOf(requireNonNull(hashVariables, "hashVariables is null")); } public PlanNode getNode() @@ -956,25 +949,25 @@ public PlanNode getNode() return node; } - public BiMap getHashSymbols() + public BiMap getHashVariables() { - return hashSymbols; + return hashVariables; } - public Symbol getRequiredHashSymbol(HashComputation hash) + public VariableReferenceExpression getRequiredHashVariable(HashComputation hash) { - Symbol hashSymbol = hashSymbols.get(hash); - requireNonNull(hashSymbol, () -> "No hash symbol generated for " + hash); - return hashSymbol; + VariableReferenceExpression hashVariable = hashVariables.get(hash); + requireNonNull(hashVariable, () -> "No hash variable generated for " + hash); + return hashVariable; } } - private static Map computeIdentityTranslations(Map assignments) + private static Map computeIdentityTranslations(Map assignments) { - Map outputToInput = new HashMap<>(); - for (Map.Entry assignment : assignments.entrySet()) { + Map outputToInput = new HashMap<>(); + for (Map.Entry assignment : assignments.entrySet()) { if (assignment.getValue() instanceof SymbolReference) { - outputToInput.put(assignment.getKey(), Symbol.from(assignment.getValue())); + outputToInput.put(assignment.getKey(), new VariableReferenceExpression(((SymbolReference) assignment.getValue()).getName(), assignment.getKey().getType())); } } return outputToInput; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java index 13b2830b80dde..49729e6f4a3de 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java @@ -18,10 +18,10 @@ import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.ExpressionUtils; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -150,19 +150,19 @@ public PlanNode visitIntersect(IntersectNode node, RewriteContext rewriteC .map(rewriteContext::rewrite) .collect(toList()); - List markers = allocateSymbols(sources.size(), MARKER, BOOLEAN); + List markers = allocateVariables(sources.size(), MARKER, BOOLEAN); // identity projection for all the fields in each of the sources plus marker columns List withMarkers = appendMarkers(markers, sources, node); // add a union over all the rewritten sources. The outputs of the union have the same name as the // original intersect node - List outputs = node.getOutputSymbols(); + List outputs = node.getOutputVariables(); UnionNode union = union(withMarkers, ImmutableList.copyOf(concat(outputs, markers))); // add count aggregations and filter rows where any of the counts is >= 1 - List aggregationOutputs = allocateSymbols(markers.size(), "count", BIGINT); - AggregationNode aggregation = computeCounts(union, outputs, markers, aggregationOutputs); + List aggregationOutputs = allocateVariables(markers.size(), "count", BIGINT); + AggregationNode aggregation = computeCounts(union, node.getOutputVariables(), markers, aggregationOutputs); FilterNode filterNode = addFilterForIntersect(aggregation); return project(filterNode, outputs); @@ -175,81 +175,81 @@ public PlanNode visitExcept(ExceptNode node, RewriteContext rewriteContext .map(rewriteContext::rewrite) .collect(toList()); - List markers = allocateSymbols(sources.size(), MARKER, BOOLEAN); + List markers = allocateVariables(sources.size(), MARKER, BOOLEAN); // identity projection for all the fields in each of the sources plus marker columns List withMarkers = appendMarkers(markers, sources, node); // add a union over all the rewritten sources. The outputs of the union have the same name as the // original except node - List outputs = node.getOutputSymbols(); + List outputs = node.getOutputVariables(); UnionNode union = union(withMarkers, ImmutableList.copyOf(concat(outputs, markers))); // add count aggregations and filter rows where count for the first source is >= 1 and all others are 0 - List aggregationOutputs = allocateSymbols(markers.size(), "count", BIGINT); - AggregationNode aggregation = computeCounts(union, outputs, markers, aggregationOutputs); + List aggregationOutputs = allocateVariables(markers.size(), "count", BIGINT); + AggregationNode aggregation = computeCounts(union, node.getOutputVariables(), markers, aggregationOutputs); FilterNode filterNode = addFilterForExcept(aggregation, aggregationOutputs.get(0), aggregationOutputs.subList(1, aggregationOutputs.size())); return project(filterNode, outputs); } - private List allocateSymbols(int count, String nameHint, Type type) + private List allocateVariables(int count, String nameHint, Type type) { - ImmutableList.Builder symbolsBuilder = ImmutableList.builder(); + ImmutableList.Builder variablesBuilder = ImmutableList.builder(); for (int i = 0; i < count; i++) { - symbolsBuilder.add(symbolAllocator.newSymbol(nameHint, type)); + variablesBuilder.add(symbolAllocator.newVariable(nameHint, type)); } - return symbolsBuilder.build(); + return variablesBuilder.build(); } - private List appendMarkers(List markers, List nodes, SetOperationNode node) + private List appendMarkers(List markers, List nodes, SetOperationNode node) { ImmutableList.Builder result = ImmutableList.builder(); for (int i = 0; i < nodes.size(); i++) { - result.add(appendMarkers(nodes.get(i), i, markers, Maps.transformValues(node.sourceSymbolMap(i), Symbol::toSymbolReference))); + result.add(appendMarkers(nodes.get(i), i, markers, Maps.transformValues(node.sourceVariableMap(i), variable -> new SymbolReference(variable.getName())))); } return result.build(); } - private PlanNode appendMarkers(PlanNode source, int markerIndex, List markers, Map projections) + private PlanNode appendMarkers(PlanNode source, int markerIndex, List markers, Map projections) { Assignments.Builder assignments = Assignments.builder(); // add existing intersect symbols to projection - for (Map.Entry entry : projections.entrySet()) { - Symbol symbol = symbolAllocator.newSymbol(entry.getKey().getName(), symbolAllocator.getTypes().get(entry.getKey())); - assignments.put(symbol, entry.getValue()); + for (Map.Entry entry : projections.entrySet()) { + VariableReferenceExpression variable = symbolAllocator.newVariable(entry.getKey().getName(), entry.getKey().getType()); + assignments.put(variable, entry.getValue()); } // add extra marker fields to the projection for (int i = 0; i < markers.size(); ++i) { Expression expression = (i == markerIndex) ? TRUE_LITERAL : new Cast(new NullLiteral(), StandardTypes.BOOLEAN); - assignments.put(symbolAllocator.newSymbol(markers.get(i).getName(), BOOLEAN), expression); + assignments.put(symbolAllocator.newVariable(markers.get(i).getName(), BOOLEAN), expression); } return new ProjectNode(idAllocator.getNextId(), source, assignments.build()); } - private UnionNode union(List nodes, List outputs) + private UnionNode union(List nodes, List outputs) { - ImmutableListMultimap.Builder outputsToInputs = ImmutableListMultimap.builder(); + ImmutableListMultimap.Builder outputsToInputs = ImmutableListMultimap.builder(); for (PlanNode source : nodes) { - for (int i = 0; i < source.getOutputSymbols().size(); i++) { - outputsToInputs.put(outputs.get(i), source.getOutputSymbols().get(i)); + for (int i = 0; i < source.getOutputVariables().size(); i++) { + outputsToInputs.put(outputs.get(i), source.getOutputVariables().get(i)); } } - return new UnionNode(idAllocator.getNextId(), nodes, outputsToInputs.build(), outputs); + return new UnionNode(idAllocator.getNextId(), nodes, outputsToInputs.build()); } - private AggregationNode computeCounts(UnionNode sourceNode, List originalColumns, List markers, List aggregationOutputs) + private AggregationNode computeCounts(UnionNode sourceNode, List originalColumns, List markers, List aggregationOutputs) { - ImmutableMap.Builder aggregations = ImmutableMap.builder(); + ImmutableMap.Builder aggregations = ImmutableMap.builder(); for (int i = 0; i < markers.size(); i++) { - Symbol output = aggregationOutputs.get(i); + VariableReferenceExpression output = aggregationOutputs.get(i); aggregations.put(output, new Aggregation( functionResolution.countFunction(BIGINT), - ImmutableList.of(markers.get(i).toSymbolReference()), + ImmutableList.of(new SymbolReference(markers.get(i).getName())), Optional.empty(), Optional.empty(), false, @@ -269,23 +269,23 @@ private AggregationNode computeCounts(UnionNode sourceNode, List origina private FilterNode addFilterForIntersect(AggregationNode aggregation) { ImmutableList predicates = aggregation.getAggregations().keySet().stream() - .map(column -> new ComparisonExpression(GREATER_THAN_OR_EQUAL, column.toSymbolReference(), new GenericLiteral("BIGINT", "1"))) + .map(column -> new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(column.getName()), new GenericLiteral("BIGINT", "1"))) .collect(toImmutableList()); return new FilterNode(idAllocator.getNextId(), aggregation, castToRowExpression(ExpressionUtils.and(predicates))); } - private FilterNode addFilterForExcept(AggregationNode aggregation, Symbol firstSource, List remainingSources) + private FilterNode addFilterForExcept(AggregationNode aggregation, VariableReferenceExpression firstSource, List remainingSources) { ImmutableList.Builder predicatesBuilder = ImmutableList.builder(); - predicatesBuilder.add(new ComparisonExpression(GREATER_THAN_OR_EQUAL, firstSource.toSymbolReference(), new GenericLiteral("BIGINT", "1"))); - for (Symbol symbol : remainingSources) { - predicatesBuilder.add(new ComparisonExpression(EQUAL, symbol.toSymbolReference(), new GenericLiteral("BIGINT", "0"))); + predicatesBuilder.add(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference(firstSource.getName()), new GenericLiteral("BIGINT", "1"))); + for (VariableReferenceExpression variable : remainingSources) { + predicatesBuilder.add(new ComparisonExpression(EQUAL, new SymbolReference(variable.getName()), new GenericLiteral("BIGINT", "0"))); } return new FilterNode(idAllocator.getNextId(), aggregation, castToRowExpression(ExpressionUtils.and(predicatesBuilder.build()))); } - private ProjectNode project(PlanNode node, List columns) + private ProjectNode project(PlanNode node, List columns) { return new ProjectNode( idAllocator.getNextId(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java index 2ef1ee679211e..9fac636368731 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.ExpressionDomainTranslator; import com.facebook.presto.sql.planner.LiteralEncoder; import com.facebook.presto.sql.planner.Symbol; @@ -42,7 +43,6 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.base.Functions; -import com.google.common.collect.ImmutableBiMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -113,33 +113,33 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) PlanNode rightRewritten = context.rewrite(node.getRight()); if (!node.getCriteria().isEmpty()) { // Index join only possible with JOIN criteria - List leftJoinSymbols = Lists.transform(node.getCriteria(), JoinNode.EquiJoinClause::getLeft); - List rightJoinSymbols = Lists.transform(node.getCriteria(), JoinNode.EquiJoinClause::getRight); + List leftJoinVariables = Lists.transform(node.getCriteria(), JoinNode.EquiJoinClause::getLeft); + List rightJoinVariables = Lists.transform(node.getCriteria(), JoinNode.EquiJoinClause::getRight); Optional leftIndexCandidate = IndexSourceRewriter.rewriteWithIndex( leftRewritten, - ImmutableSet.copyOf(leftJoinSymbols), + ImmutableSet.copyOf(leftJoinVariables), symbolAllocator, idAllocator, metadata, session); if (leftIndexCandidate.isPresent()) { // Sanity check that we can trace the path for the index lookup key - Map trace = IndexKeyTracer.trace(leftIndexCandidate.get(), ImmutableSet.copyOf(leftJoinSymbols)); - checkState(!trace.isEmpty() && leftJoinSymbols.containsAll(trace.keySet())); + Map trace = IndexKeyTracer.trace(leftIndexCandidate.get(), ImmutableSet.copyOf(leftJoinVariables)); + checkState(!trace.isEmpty() && leftJoinVariables.containsAll(trace.keySet())); } Optional rightIndexCandidate = IndexSourceRewriter.rewriteWithIndex( rightRewritten, - ImmutableSet.copyOf(rightJoinSymbols), + ImmutableSet.copyOf(rightJoinVariables), symbolAllocator, idAllocator, metadata, session); if (rightIndexCandidate.isPresent()) { // Sanity check that we can trace the path for the index lookup key - Map trace = IndexKeyTracer.trace(rightIndexCandidate.get(), ImmutableSet.copyOf(rightJoinSymbols)); - checkState(!trace.isEmpty() && rightJoinSymbols.containsAll(trace.keySet())); + Map trace = IndexKeyTracer.trace(rightIndexCandidate.get(), ImmutableSet.copyOf(rightJoinVariables)); + checkState(!trace.isEmpty() && rightJoinVariables.containsAll(trace.keySet())); } switch (node.getType()) { @@ -147,10 +147,10 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) // Prefer the right candidate over the left candidate PlanNode indexJoinNode = null; if (rightIndexCandidate.isPresent()) { - indexJoinNode = new IndexJoinNode(idAllocator.getNextId(), IndexJoinNode.Type.INNER, leftRewritten, rightIndexCandidate.get(), createEquiJoinClause(leftJoinSymbols, rightJoinSymbols), Optional.empty(), Optional.empty()); + indexJoinNode = new IndexJoinNode(idAllocator.getNextId(), IndexJoinNode.Type.INNER, leftRewritten, rightIndexCandidate.get(), createEquiJoinClause(leftJoinVariables, rightJoinVariables), Optional.empty(), Optional.empty()); } else if (leftIndexCandidate.isPresent()) { - indexJoinNode = new IndexJoinNode(idAllocator.getNextId(), IndexJoinNode.Type.INNER, rightRewritten, leftIndexCandidate.get(), createEquiJoinClause(rightJoinSymbols, leftJoinSymbols), Optional.empty(), Optional.empty()); + indexJoinNode = new IndexJoinNode(idAllocator.getNextId(), IndexJoinNode.Type.INNER, rightRewritten, leftIndexCandidate.get(), createEquiJoinClause(rightJoinVariables, leftJoinVariables), Optional.empty(), Optional.empty()); } if (indexJoinNode != null) { @@ -158,11 +158,11 @@ else if (leftIndexCandidate.isPresent()) { indexJoinNode = new FilterNode(idAllocator.getNextId(), indexJoinNode, node.getFilter().get()); } - if (!indexJoinNode.getOutputSymbols().equals(node.getOutputSymbols())) { + if (!indexJoinNode.getOutputVariables().equals(node.getOutputVariables())) { indexJoinNode = new ProjectNode( idAllocator.getNextId(), indexJoinNode, - Assignments.identity(node.getOutputSymbols())); + Assignments.identity(node.getOutputVariables())); } return indexJoinNode; @@ -172,14 +172,14 @@ else if (leftIndexCandidate.isPresent()) { case LEFT: // We cannot use indices for outer joins until index join supports in-line filtering if (!node.getFilter().isPresent() && rightIndexCandidate.isPresent()) { - return createIndexJoinWithExpectedOutputs(node.getOutputSymbols(), IndexJoinNode.Type.SOURCE_OUTER, leftRewritten, rightIndexCandidate.get(), createEquiJoinClause(leftJoinSymbols, rightJoinSymbols), idAllocator); + return createIndexJoinWithExpectedOutputs(node.getOutputVariables(), IndexJoinNode.Type.SOURCE_OUTER, leftRewritten, rightIndexCandidate.get(), createEquiJoinClause(leftJoinVariables, rightJoinVariables), idAllocator, symbolAllocator); } break; case RIGHT: // We cannot use indices for outer joins until index join supports in-line filtering if (!node.getFilter().isPresent() && leftIndexCandidate.isPresent()) { - return createIndexJoinWithExpectedOutputs(node.getOutputSymbols(), IndexJoinNode.Type.SOURCE_OUTER, rightRewritten, leftIndexCandidate.get(), createEquiJoinClause(rightJoinSymbols, leftJoinSymbols), idAllocator); + return createIndexJoinWithExpectedOutputs(node.getOutputVariables(), IndexJoinNode.Type.SOURCE_OUTER, rightRewritten, leftIndexCandidate.get(), createEquiJoinClause(rightJoinVariables, leftJoinVariables), idAllocator, symbolAllocator); } break; @@ -192,15 +192,22 @@ else if (leftIndexCandidate.isPresent()) { } if (leftRewritten != node.getLeft() || rightRewritten != node.getRight()) { - return new JoinNode(node.getId(), node.getType(), leftRewritten, rightRewritten, node.getCriteria(), node.getOutputSymbols(), node.getFilter(), node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType()); + return new JoinNode(node.getId(), node.getType(), leftRewritten, rightRewritten, node.getCriteria(), node.getOutputVariables(), node.getFilter(), node.getLeftHashVariable(), node.getRightHashVariable(), node.getDistributionType()); } return node; } - private static PlanNode createIndexJoinWithExpectedOutputs(List expectedOutputs, IndexJoinNode.Type type, PlanNode probe, PlanNode index, List equiJoinClause, PlanNodeIdAllocator idAllocator) + private static PlanNode createIndexJoinWithExpectedOutputs( + List expectedOutputs, + IndexJoinNode.Type type, + PlanNode probe, + PlanNode index, + List equiJoinClause, + PlanNodeIdAllocator idAllocator, + SymbolAllocator symbolAllocator) { PlanNode result = new IndexJoinNode(idAllocator.getNextId(), type, probe, index, equiJoinClause, Optional.empty(), Optional.empty()); - if (!result.getOutputSymbols().equals(expectedOutputs)) { + if (!result.getOutputVariables().equals(expectedOutputs)) { result = new ProjectNode( idAllocator.getNextId(), result, @@ -209,12 +216,12 @@ private static PlanNode createIndexJoinWithExpectedOutputs(List expected return result; } - private static List createEquiJoinClause(List probeSymbols, List indexSymbols) + private static List createEquiJoinClause(List probeVariables, List indexVariables) { - checkArgument(probeSymbols.size() == indexSymbols.size()); + checkArgument(probeVariables.size() == indexVariables.size()); ImmutableList.Builder builder = ImmutableList.builder(); - for (int i = 0; i < probeSymbols.size(); i++) { - builder.add(new IndexJoinNode.EquiJoinClause(probeSymbols.get(i), indexSymbols.get(i))); + for (int i = 0; i < probeVariables.size(); i++) { + builder.add(new IndexJoinNode.EquiJoinClause(probeVariables.get(i), indexVariables.get(i))); } return builder.build(); } @@ -243,7 +250,7 @@ private IndexSourceRewriter(SymbolAllocator symbolAllocator, PlanNodeIdAllocator public static Optional rewriteWithIndex( PlanNode planNode, - Set lookupSymbols, + Set lookupVariables, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Metadata metadata, @@ -251,7 +258,7 @@ public static Optional rewriteWithIndex( { AtomicBoolean success = new AtomicBoolean(); IndexSourceRewriter indexSourceRewriter = new IndexSourceRewriter(symbolAllocator, idAllocator, metadata, session); - PlanNode rewritten = SimplePlanRewriter.rewriteWith(indexSourceRewriter, planNode, new Context(lookupSymbols, success)); + PlanNode rewritten = SimplePlanRewriter.rewriteWith(indexSourceRewriter, planNode, new Context(lookupVariables, success)); if (success.get()) { return Optional.of(rewritten); } @@ -280,16 +287,16 @@ private PlanNode planTableScan(TableScanNode node, Expression predicate, Context symbolAllocator.getTypes()); TupleDomain simplifiedConstraint = decomposedPredicate.getTupleDomain() - .transform(node.getAssignments()::get) + .transform(symbol -> node.getAssignments().entrySet().stream().collect(toImmutableMap(entry -> new Symbol(entry.getKey().getName()), Map.Entry::getValue)).get(symbol)) .intersect(node.getEnforcedConstraint()); - checkState(node.getOutputSymbols().containsAll(context.getLookupSymbols())); + checkState(node.getOutputVariables().containsAll(context.getLookupVariables())); - Set lookupColumns = context.getLookupSymbols().stream() - .map(node.getAssignments()::get) + Set lookupColumns = context.getLookupVariables().stream() + .map(variable -> node.getAssignments().get(variable)) .collect(toImmutableSet()); - Set outputColumns = node.getOutputSymbols().stream().map(node.getAssignments()::get).collect(toImmutableSet()); + Set outputColumns = node.getOutputVariables().stream().map(node.getAssignments()::get).collect(toImmutableSet()); Optional optionalResolvedIndex = metadata.resolveIndex(session, node.getTable(), lookupColumns, outputColumns, simplifiedConstraint); if (!optionalResolvedIndex.isPresent()) { @@ -298,14 +305,15 @@ private PlanNode planTableScan(TableScanNode node, Expression predicate, Context } ResolvedIndex resolvedIndex = optionalResolvedIndex.get(); - Map inverseAssignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); + Map inverseAssignments = node.getAssignments().entrySet().stream() + .collect(toImmutableMap(Map.Entry::getValue, entry -> new Symbol(entry.getKey().getName()))); PlanNode source = new IndexSourceNode( idAllocator.getNextId(), resolvedIndex.getIndexHandle(), node.getTable(), - context.getLookupSymbols(), - node.getOutputSymbols(), + context.getLookupVariables(), + node.getOutputVariables(), node.getAssignments(), simplifiedConstraint); @@ -324,18 +332,21 @@ private PlanNode planTableScan(TableScanNode node, Expression predicate, Context @Override public PlanNode visitProject(ProjectNode node, RewriteContext context) { - // Rewrite the lookup symbols in terms of only the pre-projected symbols that have direct translations - Set newLookupSymbols = context.get().getLookupSymbols().stream() - .map(node.getAssignments()::get) - .filter(SymbolReference.class::isInstance) - .map(Symbol::from) - .collect(toImmutableSet()); + // Rewrite the lookup variables in terms of only the pre-projected variables that have direct translations + ImmutableSet.Builder newLookupVariablesBuilder = ImmutableSet.builder(); + for (VariableReferenceExpression variable : context.get().getLookupVariables()) { + Expression expression = node.getAssignments().get(variable); + if (expression instanceof SymbolReference) { + newLookupVariablesBuilder.add(new VariableReferenceExpression(((SymbolReference) expression).getName(), variable.getType())); + } + } + ImmutableSet newLookupVariables = newLookupVariablesBuilder.build(); - if (newLookupSymbols.isEmpty()) { + if (newLookupVariables.isEmpty()) { return node; } - return context.defaultRewrite(node, new Context(newLookupSymbols, context.get().getSuccess())); + return context.defaultRewrite(node, new Context(newLookupVariables, context.get().getSuccess())); } @Override @@ -345,7 +356,7 @@ public PlanNode visitFilter(FilterNode node, RewriteContext context) return planTableScan((TableScanNode) node.getSource(), castToExpression(node.getPredicate()), context.get()); } - return context.defaultRewrite(node, new Context(context.get().getLookupSymbols(), context.get().getSuccess())); + return context.defaultRewrite(node, new Context(context.get().getLookupVariables(), context.get().getSuccess())); } @Override @@ -356,7 +367,7 @@ public PlanNode visitWindow(WindowNode node, RewriteContext context) return node; } - // Don't need this restriction if we can prove that all order by symbols are deterministically produced + // Don't need this restriction if we can prove that all order by variables are deterministically produced if (node.getOrderingScheme().isPresent()) { return node; } @@ -368,16 +379,16 @@ public PlanNode visitWindow(WindowNode node, RewriteContext context) return node; } - // Lookup symbols can only be passed through if they are part of the partitioning - Set partitionByLookupSymbols = context.get().getLookupSymbols().stream() + // Lookup variables can only be passed through if they are part of the partitioning + Set partitionByLookupVariables = context.get().getLookupVariables().stream() .filter(node.getPartitionBy()::contains) .collect(toImmutableSet()); - if (partitionByLookupSymbols.isEmpty()) { + if (partitionByLookupVariables.isEmpty()) { return node; } - return context.defaultRewrite(node, new Context(partitionByLookupSymbols, context.get().getSuccess())); + return context.defaultRewrite(node, new Context(partitionByLookupVariables, context.get().getSuccess())); } @Override @@ -389,20 +400,20 @@ public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext c @Override public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext context) { - // Lookup symbols can only be passed through the probe side of an index join - Set probeLookupSymbols = context.get().getLookupSymbols().stream() - .filter(node.getProbeSource().getOutputSymbols()::contains) + // Lookup variables can only be passed through the probe side of an index join + Set probeLookupVariables = context.get().getLookupVariables().stream() + .filter(node.getProbeSource().getOutputVariables()::contains) .collect(toImmutableSet()); - if (probeLookupSymbols.isEmpty()) { + if (probeLookupVariables.isEmpty()) { return node; } - PlanNode rewrittenProbeSource = context.rewrite(node.getProbeSource(), new Context(probeLookupSymbols, context.get().getSuccess())); + PlanNode rewrittenProbeSource = context.rewrite(node.getProbeSource(), new Context(probeLookupVariables, context.get().getSuccess())); PlanNode source = node; if (rewrittenProbeSource != node.getProbeSource()) { - source = new IndexJoinNode(node.getId(), node.getType(), rewrittenProbeSource, node.getIndexSource(), node.getCriteria(), node.getProbeHashSymbol(), node.getIndexHashSymbol()); + source = new IndexJoinNode(node.getId(), node.getType(), rewrittenProbeSource, node.getIndexSource(), node.getCriteria(), node.getProbeHashVariable(), node.getIndexHashVariable()); } return source; @@ -411,16 +422,16 @@ public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext conte @Override public PlanNode visitAggregation(AggregationNode node, RewriteContext context) { - // Lookup symbols can only be passed through if they are part of the group by columns - Set groupByLookupSymbols = context.get().getLookupSymbols().stream() + // Lookup variables can only be passed through if they are part of the group by columns + Set groupByLookupVariables = context.get().getLookupVariables().stream() .filter(node.getGroupingKeys()::contains) .collect(toImmutableSet()); - if (groupByLookupSymbols.isEmpty()) { + if (groupByLookupVariables.isEmpty()) { return node; } - return context.defaultRewrite(node, new Context(groupByLookupSymbols, context.get().getSuccess())); + return context.defaultRewrite(node, new Context(groupByLookupVariables, context.get().getSuccess())); } @Override @@ -432,19 +443,20 @@ public PlanNode visitSort(SortNode node, RewriteContext context) public static class Context { - private final Set lookupSymbols; + private final Set lookupVariables; private final AtomicBoolean success; - public Context(Set lookupSymbols, AtomicBoolean success) + public Context(Set lookupVariables, AtomicBoolean success) { - checkArgument(!lookupSymbols.isEmpty(), "lookupSymbols can not be empty"); - this.lookupSymbols = ImmutableSet.copyOf(requireNonNull(lookupSymbols, "lookupSymbols is null")); + requireNonNull(lookupVariables, "lookupVariables is null"); + checkArgument(!lookupVariables.isEmpty(), "lookupVariables can not be empty"); + this.lookupVariables = ImmutableSet.copyOf(lookupVariables); this.success = requireNonNull(success, "success is null"); } - public Set getLookupSymbols() + public Set getLookupVariables() { - return lookupSymbols; + return lookupVariables; } public AtomicBoolean getSuccess() @@ -460,93 +472,93 @@ public void markSuccess() } /** - * Identify the mapping from the lookup symbols used at the top of the index plan to - * the actual symbols produced by the IndexSource. Note that multiple top-level lookup symbols may share the same - * underlying IndexSource symbol. Also note that lookup symbols that do not correspond to underlying index source symbols + * Identify the mapping from the lookup variables used at the top of the index plan to + * the actual variables produced by the IndexSource. Note that multiple top-level lookup variables may share the same + * underlying IndexSource symbol. Also note that lookup variables that do not correspond to underlying index source variables * will be omitted from the returned Map. */ public static class IndexKeyTracer { - public static Map trace(PlanNode node, Set lookupSymbols) + public static Map trace(PlanNode node, Set lookupVariables) { - return node.accept(new Visitor(), lookupSymbols); + return node.accept(new Visitor(), lookupVariables); } private static class Visitor - extends InternalPlanVisitor, Set> + extends InternalPlanVisitor, Set> { @Override - protected Map visitPlan(PlanNode node, Set lookupSymbols) + protected Map visitPlan(PlanNode node, Set lookupVariables) { throw new UnsupportedOperationException("Node not expected to be part of Index pipeline: " + node); } @Override - public Map visitProject(ProjectNode node, Set lookupSymbols) + public Map visitProject(ProjectNode node, Set lookupVariables) { - // Map from output Symbols to source Symbols - Map directSymbolTranslationOutputMap = Maps.transformValues(Maps.filterValues(node.getAssignments().getMap(), SymbolReference.class::isInstance), Symbol::from); - Map outputToSourceMap = lookupSymbols.stream() + // Map from output variables to source variables + Map directSymbolTranslationOutputMap = Maps.transformValues(Maps.filterValues(node.getAssignments().getMap(), SymbolReference.class::isInstance), Symbol::from); + Map outputToSourceMap = lookupVariables.stream() .filter(directSymbolTranslationOutputMap.keySet()::contains) - .collect(toImmutableMap(identity(), directSymbolTranslationOutputMap::get)); + .collect(toImmutableMap(identity(), variable -> new VariableReferenceExpression(directSymbolTranslationOutputMap.get(variable).getName(), variable.getType()))); - checkState(!outputToSourceMap.isEmpty(), "No lookup symbols were able to pass through the projection"); + checkState(!outputToSourceMap.isEmpty(), "No lookup variables were able to pass through the projection"); - // Map from source Symbols to underlying index source Symbols - Map sourceToIndexMap = node.getSource().accept(this, ImmutableSet.copyOf(outputToSourceMap.values())); + // Map from source variables to underlying index source variables + Map sourceToIndexMap = node.getSource().accept(this, ImmutableSet.copyOf(outputToSourceMap.values())); - // Generate the Map the connects lookup symbols to underlying index source symbols - Map outputToIndexMap = Maps.transformValues(Maps.filterValues(outputToSourceMap, in(sourceToIndexMap.keySet())), Functions.forMap(sourceToIndexMap)); + // Generate the Map the connects lookup variables to underlying index source variables + Map outputToIndexMap = Maps.transformValues(Maps.filterValues(outputToSourceMap, in(sourceToIndexMap.keySet())), Functions.forMap(sourceToIndexMap)); return ImmutableMap.copyOf(outputToIndexMap); } @Override - public Map visitFilter(FilterNode node, Set lookupSymbols) + public Map visitFilter(FilterNode node, Set lookupVariables) { - return node.getSource().accept(this, lookupSymbols); + return node.getSource().accept(this, lookupVariables); } @Override - public Map visitWindow(WindowNode node, Set lookupSymbols) + public Map visitWindow(WindowNode node, Set lookupVariables) { - Set partitionByLookupSymbols = lookupSymbols.stream() + Set partitionByLookupVariables = lookupVariables.stream() .filter(node.getPartitionBy()::contains) .collect(toImmutableSet()); - checkState(!partitionByLookupSymbols.isEmpty(), "No lookup symbols were able to pass through the aggregation group by"); - return node.getSource().accept(this, partitionByLookupSymbols); + checkState(!partitionByLookupVariables.isEmpty(), "No lookup variables were able to pass through the aggregation group by"); + return node.getSource().accept(this, partitionByLookupVariables); } @Override - public Map visitIndexJoin(IndexJoinNode node, Set lookupSymbols) + public Map visitIndexJoin(IndexJoinNode node, Set lookupVariables) { - Set probeLookupSymbols = lookupSymbols.stream() - .filter(node.getProbeSource().getOutputSymbols()::contains) + Set probeLookupVariables = lookupVariables.stream() + .filter(node.getProbeSource().getOutputVariables()::contains) .collect(toImmutableSet()); - checkState(!probeLookupSymbols.isEmpty(), "No lookup symbols were able to pass through the index join probe source"); - return node.getProbeSource().accept(this, probeLookupSymbols); + checkState(!probeLookupVariables.isEmpty(), "No lookup variables were able to pass through the index join probe source"); + return node.getProbeSource().accept(this, probeLookupVariables); } @Override - public Map visitAggregation(AggregationNode node, Set lookupSymbols) + public Map visitAggregation(AggregationNode node, Set lookupVariables) { - Set groupByLookupSymbols = lookupSymbols.stream() + Set groupByLookupVariables = lookupVariables.stream() .filter(node.getGroupingKeys()::contains) .collect(toImmutableSet()); - checkState(!groupByLookupSymbols.isEmpty(), "No lookup symbols were able to pass through the aggregation group by"); - return node.getSource().accept(this, groupByLookupSymbols); + checkState(!groupByLookupVariables.isEmpty(), "No lookup variables were able to pass through the aggregation group by"); + return node.getSource().accept(this, groupByLookupVariables); } @Override - public Map visitSort(SortNode node, Set lookupSymbols) + public Map visitSort(SortNode node, Set lookupVariables) { - return node.getSource().accept(this, lookupSymbols); + return node.getSource().accept(this, lookupVariables); } @Override - public Map visitIndexSource(IndexSourceNode node, Set lookupSymbols) + public Map visitIndexSource(IndexSourceNode node, Set lookupVariables) { - checkState(node.getLookupSymbols().equals(lookupSymbols), "lookupSymbols must be the same as IndexSource lookup symbols"); - return lookupSymbols.stream().collect(toImmutableMap(identity(), identity())); + checkState(node.getLookupVariables().equals(lookupVariables), "lookupVariables must be the same as IndexSource lookup variables"); + return lookupVariables.stream().collect(toImmutableMap(identity(), identity())); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinNodeUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinNodeUtils.java index 491376d0c8fba..8f4c1458e9976 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinNodeUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinNodeUtils.java @@ -17,6 +17,7 @@ import com.facebook.presto.sql.planner.plan.JoinNode.EquiJoinClause; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Join; +import com.facebook.presto.sql.tree.SymbolReference; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; @@ -30,7 +31,7 @@ private JoinNodeUtils() {} public static ComparisonExpression toExpression(EquiJoinClause clause) { - return new ComparisonExpression(EQUAL, clause.getLeft().toSymbolReference(), clause.getRight().toSymbolReference()); + return new ComparisonExpression(EQUAL, new SymbolReference(clause.getLeft().getName()), new SymbolReference(clause.getRight().getName())); } public static JoinNode.Type typeConvert(Join.Type joinType) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/LimitPushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/LimitPushDown.java index dc2b0372a0d5a..8a02c627b5c25 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/LimitPushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/LimitPushDown.java @@ -119,7 +119,7 @@ public PlanNode visitLimit(LimitNode node, RewriteContext context) // return empty ValuesNode in case of limit 0 if (count == 0) { return new ValuesNode(idAllocator.getNextId(), - node.getOutputSymbols(), + node.getOutputVariables(), ImmutableList.of()); } @@ -135,10 +135,10 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext context) sources.add(context.rewrite(node.getSources().get(i), childLimit)); } - PlanNode output = new UnionNode(node.getId(), sources, node.getSymbolMapping(), node.getOutputSymbols()); + PlanNode output = new UnionNode(node.getId(), sources, node.getVariableMapping()); if (limit != null) { output = new LimitNode(idAllocator.getNextId(), output, limit.getCount(), limit.isPartial()); } @@ -228,11 +228,11 @@ public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext co node.getId(), source, node.getFilteringSource(), - node.getSourceJoinSymbol(), - node.getFilteringSourceJoinSymbol(), + node.getSourceJoinVariable(), + node.getFilteringSourceJoinVariable(), node.getSemiJoinOutput(), - node.getSourceHashSymbol(), - node.getFilteringSourceHashSymbol(), + node.getSourceHashVariable(), + node.getFilteringSourceHashVariable(), node.getDistributionType()); } return node; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataDeleteOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataDeleteOptimizer.java index 9a94c092ea42f..f8ed16841dc11 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataDeleteOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataDeleteOptimizer.java @@ -97,7 +97,7 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext cont return new MetadataDeleteNode( idAllocator.getNextId(), new DeleteHandle(tableScanNode.getTable(), delete.get().getTarget().getSchemaTableName()), - Iterables.getOnlyElement(node.getOutputSymbols())); + Iterables.getOnlyElement(node.getOutputVariables())); } private static Optional findNode(PlanNode source, Class clazz) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java index eaff22718dd46..5ef7d16d95a2d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java @@ -19,17 +19,15 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.TableLayout; import com.facebook.presto.spi.ColumnHandle; -import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.DiscretePredicates; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.predicate.NullableValue; import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.relation.RowExpression; -import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.ExpressionDeterminismEvaluator; import com.facebook.presto.sql.planner.LiteralEncoder; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -121,20 +119,15 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont // verify all outputs of table scan are partition keys TableScanNode tableScan = result.get(); - ImmutableMap.Builder typesBuilder = ImmutableMap.builder(); - ImmutableMap.Builder columnBuilder = ImmutableMap.builder(); + ImmutableMap.Builder columnBuilder = ImmutableMap.builder(); - List inputs = tableScan.getOutputSymbols(); - for (Symbol symbol : inputs) { - ColumnHandle column = tableScan.getAssignments().get(symbol); - ColumnMetadata columnMetadata = metadata.getColumnMetadata(session, tableScan.getTable(), column); - - typesBuilder.put(symbol, columnMetadata.getType()); - columnBuilder.put(symbol, column); + List inputs = tableScan.getOutputVariables(); + for (VariableReferenceExpression variable : inputs) { + ColumnHandle column = tableScan.getAssignments().get(variable); + columnBuilder.put(variable, column); } - Map columns = columnBuilder.build(); - Map types = typesBuilder.build(); + Map columns = columnBuilder.build(); // Materialize the list of partitions and replace the TableScan node // with a Values node @@ -163,16 +156,15 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont ImmutableList.Builder rowBuilder = ImmutableList.builder(); // for each input column, add a literal expression using the entry value - for (Symbol input : inputs) { + for (VariableReferenceExpression input : inputs) { ColumnHandle column = columns.get(input); - Type type = types.get(input); NullableValue value = entries.get(column); if (value == null) { // partition key does not have a single value, so bail out to be safe return context.defaultRewrite(node); } else { - rowBuilder.add(constant(value.getValue(), type)); + rowBuilder.add(constant(value.getValue(), input.getType())); } } rowsBuilder.add(rowBuilder.build()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index 5f0530dd3e915..44195a4d960b8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -17,6 +17,7 @@ import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.TypeSignatureProvider; @@ -38,6 +39,7 @@ import com.facebook.presto.sql.tree.IfExpression; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.NullLiteral; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -56,6 +58,7 @@ import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; /* * This optimizer convert query of form: @@ -110,9 +113,9 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext masks = node.getAggregations().values().stream() + List masks = node.getAggregations().values().stream() .map(Aggregation::getMask).filter(Optional::isPresent).map(Optional::get).collect(toImmutableList()); - Set uniqueMasks = ImmutableSet.copyOf(masks); + Set uniqueMasks = ImmutableSet.copyOf(masks); if (uniqueMasks.size() != 1 || masks.size() == node.getAggregations().size()) { return context.defaultRewrite(node, Optional.empty()); } @@ -130,7 +133,8 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext aggregations = ImmutableMap.builder(); + ImmutableMap.Builder aggregations = ImmutableMap.builder(); // Add coalesce projection node to handle count(), count_if(), approx_distinct() functions return a // non-null result without any input - ImmutableMap.Builder coalesceSymbolsBuilder = ImmutableMap.builder(); - for (Map.Entry entry : node.getAggregations().entrySet()) { + ImmutableMap.Builder coalesceVariablesBuilder = ImmutableMap.builder(); + for (Map.Entry entry : node.getAggregations().entrySet()) { if (entry.getValue().getMask().isPresent()) { aggregations.put(entry.getKey(), new Aggregation( entry.getValue().getFunctionHandle(), - ImmutableList.of(aggregateInfo.getNewDistinctAggregateSymbol().toSymbolReference()), + ImmutableList.of(new SymbolReference(aggregateInfo.getNewDistinctAggregateVariable().getName())), Optional.empty(), Optional.empty(), false, @@ -164,26 +168,26 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext coalesceSymbols = coalesceSymbolsBuilder.build(); + Map coalesceVariables = coalesceVariablesBuilder.build(); AggregationNode aggregationNode = new AggregationNode( idAllocator.getNextId(), @@ -193,24 +197,24 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext aggregateInfo = context.get(); // presence of aggregateInfo => mask also present - if (!aggregateInfo.isPresent() || !aggregateInfo.get().getMask().equals(node.getMarkerSymbol())) { + if (!aggregateInfo.isPresent() || !aggregateInfo.get().getMask().equals(node.getMarkerVariable())) { return context.defaultRewrite(node, Optional.empty()); } @@ -227,61 +231,61 @@ public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext allSymbols = new HashSet<>(); - List groupBySymbols = aggregateInfo.get().getGroupBySymbols(); // a - List nonDistinctAggregateSymbols = aggregateInfo.get().getOriginalNonDistinctAggregateArgs(); //b - Symbol distinctSymbol = Iterables.getOnlyElement(aggregateInfo.get().getOriginalDistinctAggregateArgs()); // c + Set allVariables = new HashSet<>(); + List groupByVariables = aggregateInfo.get().getGroupByVariables(); // a + List nonDistinctAggregateVariables = aggregateInfo.get().getOriginalNonDistinctAggregateArgs(); //b + VariableReferenceExpression distinctVariable = Iterables.getOnlyElement(aggregateInfo.get().getOriginalDistinctAggregateArgs()); // c // If same symbol present in aggregations on distinct and non-distinct values, e.g. select sum(a), count(distinct a), // then we need to create a duplicate stream for this symbol - Symbol duplicatedDistinctSymbol = distinctSymbol; + VariableReferenceExpression duplicatedDistinctVariable = distinctVariable; - if (nonDistinctAggregateSymbols.contains(distinctSymbol)) { - Symbol newSymbol = symbolAllocator.newSymbol(distinctSymbol.getName(), symbolAllocator.getTypes().get(distinctSymbol)); - nonDistinctAggregateSymbols.set(nonDistinctAggregateSymbols.indexOf(distinctSymbol), newSymbol); - duplicatedDistinctSymbol = newSymbol; + if (nonDistinctAggregateVariables.contains(distinctVariable)) { + VariableReferenceExpression newVariable = symbolAllocator.newVariable(distinctVariable); + nonDistinctAggregateVariables.set(nonDistinctAggregateVariables.indexOf(distinctVariable), newVariable); + duplicatedDistinctVariable = newVariable; } - allSymbols.addAll(groupBySymbols); - allSymbols.addAll(nonDistinctAggregateSymbols); - allSymbols.add(distinctSymbol); + allVariables.addAll(groupByVariables); + allVariables.addAll(nonDistinctAggregateVariables); + allVariables.add(distinctVariable); // 1. Add GroupIdNode - Symbol groupSymbol = symbolAllocator.newSymbol("group", BigintType.BIGINT); // g + VariableReferenceExpression groupVariable = symbolAllocator.newVariable("group", BigintType.BIGINT); // g GroupIdNode groupIdNode = createGroupIdNode( - groupBySymbols, - nonDistinctAggregateSymbols, - distinctSymbol, - duplicatedDistinctSymbol, - groupSymbol, - allSymbols, + groupByVariables, + nonDistinctAggregateVariables, + distinctVariable, + duplicatedDistinctVariable, + groupVariable, + allVariables, source); // 2. Add aggregation node - Set groupByKeys = new HashSet<>(groupBySymbols); - groupByKeys.add(distinctSymbol); - groupByKeys.add(groupSymbol); + Set groupByKeys = new HashSet<>(groupByVariables); + groupByKeys.add(distinctVariable); + groupByKeys.add(groupVariable); - ImmutableMap.Builder aggregationOutputSymbolsMapBuilder = ImmutableMap.builder(); + ImmutableMap.Builder aggregationOutputVariablesMapBuilder = ImmutableMap.builder(); AggregationNode aggregationNode = createNonDistinctAggregation( aggregateInfo.get(), - distinctSymbol, - duplicatedDistinctSymbol, + distinctVariable, + duplicatedDistinctVariable, groupByKeys, groupIdNode, node, - aggregationOutputSymbolsMapBuilder); + aggregationOutputVariablesMapBuilder); // This map has mapping only for aggregation on non-distinct symbols which the new AggregationNode handles - Map aggregationOutputSymbolsMap = aggregationOutputSymbolsMapBuilder.build(); + Map aggregationOutputVariablesMap = aggregationOutputVariablesMapBuilder.build(); // 3. Add new project node that adds if expressions ProjectNode projectNode = createProjectNode( aggregationNode, aggregateInfo.get(), - distinctSymbol, - groupSymbol, - groupBySymbols, - aggregationOutputSymbolsMap); + distinctVariable, + groupVariable, + groupByVariables, + aggregationOutputVariablesMap); return projectNode; } @@ -289,14 +293,13 @@ public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext groupBySymbols, - Map aggregationOutputSymbolsMap) + VariableReferenceExpression distinctVariable, + VariableReferenceExpression groupVariable, + List groupByVariables, + Map aggregationOutputVariablesMap) { - Assignments.Builder outputSymbols = Assignments.builder(); - ImmutableMap.Builder outputNonDistinctAggregateSymbols = ImmutableMap.builder(); - for (Symbol symbol : source.getOutputSymbols()) { - if (distinctSymbol.equals(symbol)) { - Symbol newSymbol = symbolAllocator.newSymbol("expr", symbolAllocator.getTypes().get(symbol)); - aggregateInfo.setNewDistinctAggregateSymbol(newSymbol); + Assignments.Builder outputVariables = Assignments.builder(); + ImmutableMap.Builder outputNonDistinctAggregateVariables = ImmutableMap.builder(); + for (VariableReferenceExpression variable : source.getOutputVariables()) { + if (distinctVariable.equals(variable)) { + VariableReferenceExpression newVariable = symbolAllocator.newVariable("expr", variable.getType()); + aggregateInfo.setNewDistinctAggregateSymbol(newVariable); Expression expression = createIfExpression( - groupSymbol.toSymbolReference(), + new SymbolReference(groupVariable.getName()), new Cast(new LongLiteral("1"), "bigint"), // TODO: this should use GROUPING() when that's available instead of relying on specific group numbering ComparisonExpression.Operator.EQUAL, - symbol.toSymbolReference(), - symbolAllocator.getTypes().get(symbol)); - outputSymbols.put(newSymbol, expression); + new SymbolReference(variable.getName()), + variable.getType()); + outputVariables.put(newVariable, expression); } - else if (aggregationOutputSymbolsMap.containsKey(symbol)) { - Symbol newSymbol = symbolAllocator.newSymbol("expr", symbolAllocator.getTypes().get(symbol)); + else if (aggregationOutputVariablesMap.containsKey(variable)) { + VariableReferenceExpression newVariable = symbolAllocator.newVariable("expr", variable.getType()); // key of outputNonDistinctAggregateSymbols is key of an aggregation in AggrNode above, it will now aggregate on this Map's value - outputNonDistinctAggregateSymbols.put(aggregationOutputSymbolsMap.get(symbol), newSymbol); + outputNonDistinctAggregateVariables.put(aggregationOutputVariablesMap.get(variable), newVariable); Expression expression = createIfExpression( - groupSymbol.toSymbolReference(), + new SymbolReference(groupVariable.getName()), new Cast(new LongLiteral("0"), "bigint"), // TODO: this should use GROUPING() when that's available instead of relying on specific group numbering ComparisonExpression.Operator.EQUAL, - symbol.toSymbolReference(), - symbolAllocator.getTypes().get(symbol)); - outputSymbols.put(newSymbol, expression); + new SymbolReference(variable.getName()), + variable.getType()); + outputVariables.put(newVariable, expression); } // A symbol can appear both in groupBy and distinct/non-distinct aggregation - if (groupBySymbols.contains(symbol)) { - Expression expression = symbol.toSymbolReference(); - outputSymbols.put(symbol, expression); + if (groupByVariables.contains(variable)) { + Expression expression = new SymbolReference(variable.getName()); + outputVariables.put(variable, expression); } } // add null assignment for mask // unused mask will be removed by PruneUnreferencedOutputs - outputSymbols.put(aggregateInfo.getMask(), new NullLiteral()); + outputVariables.put(aggregateInfo.getMask(), new NullLiteral()); - aggregateInfo.setNewNonDistinctAggregateSymbols(outputNonDistinctAggregateSymbols.build()); + aggregateInfo.setNewNonDistinctAggregateSymbols(outputNonDistinctAggregateVariables.build()); - return new ProjectNode(idAllocator.getNextId(), source, outputSymbols.build()); + return new ProjectNode(idAllocator.getNextId(), source, outputVariables.build()); } private GroupIdNode createGroupIdNode( - List groupBySymbols, - List nonDistinctAggregateSymbols, - Symbol distinctSymbol, - Symbol duplicatedDistinctSymbol, - Symbol groupSymbol, - Set allSymbols, + List groupByVariables, + List nonDistinctAggregateVariables, + VariableReferenceExpression distinctVariable, + VariableReferenceExpression duplicatedDistinctVariable, + VariableReferenceExpression groupVariable, + Set allVariables, PlanNode source) { - List> groups = new ArrayList<>(); + List> groups = new ArrayList<>(); // g0 = {group-by symbols + allNonDistinctAggregateSymbols} // g1 = {group-by symbols + Distinct Symbol} // symbols present in Group_i will be set, rest will be Null //g0 - Set group0 = new HashSet<>(); - group0.addAll(groupBySymbols); - group0.addAll(nonDistinctAggregateSymbols); + Set group0 = new HashSet<>(); + group0.addAll(groupByVariables); + group0.addAll(nonDistinctAggregateVariables); groups.add(ImmutableList.copyOf(group0)); // g1 - Set group1 = new HashSet<>(groupBySymbols); - group1.add(distinctSymbol); + Set group1 = new HashSet<>(groupByVariables); + group1.add(distinctVariable); groups.add(ImmutableList.copyOf(group1)); return new GroupIdNode( idAllocator.getNextId(), source, groups, - allSymbols.stream().collect(Collectors.toMap( - symbol -> symbol, - symbol -> (symbol.equals(duplicatedDistinctSymbol) ? distinctSymbol : symbol))), + allVariables.stream().collect(Collectors.toMap( + identity(), + variable -> (variable.equals(duplicatedDistinctVariable) ? distinctVariable : variable))), ImmutableList.of(), - groupSymbol); + groupVariable); } /* @@ -412,27 +415,27 @@ private GroupIdNode createGroupIdNode( */ private AggregationNode createNonDistinctAggregation( AggregateInfo aggregateInfo, - Symbol distinctSymbol, - Symbol duplicatedDistinctSymbol, - Set groupByKeys, + VariableReferenceExpression distinctVariable, + VariableReferenceExpression duplicatedDistinctVariable, + Set groupByKeys, GroupIdNode groupIdNode, MarkDistinctNode originalNode, - ImmutableMap.Builder aggregationOutputSymbolsMapBuilder) + ImmutableMap.Builder aggregationOutputSymbolsMapBuilder) { - ImmutableMap.Builder aggregations = ImmutableMap.builder(); - for (Map.Entry entry : aggregateInfo.getAggregations().entrySet()) { + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + for (Map.Entry entry : aggregateInfo.getAggregations().entrySet()) { if (!entry.getValue().getMask().isPresent()) { - Symbol newSymbol = symbolAllocator.newSymbol(entry.getKey().toSymbolReference(), symbolAllocator.getTypes().get(entry.getKey())); + VariableReferenceExpression newVariable = symbolAllocator.newVariable(entry.getKey()); Aggregation aggregation = entry.getValue(); - aggregationOutputSymbolsMapBuilder.put(newSymbol, entry.getKey()); + aggregationOutputSymbolsMapBuilder.put(newVariable, entry.getKey()); // Handling for cases when mask symbol appears in non distinct aggregations too // Now the aggregation should happen over the duplicate symbol added before List arguments; - if (!duplicatedDistinctSymbol.equals(distinctSymbol) && entry.getValue().getArguments().contains(distinctSymbol.toSymbolReference())) { + if (!duplicatedDistinctVariable.equals(distinctVariable) && entry.getValue().getArguments().contains(new SymbolReference(distinctVariable.getName()))) { ImmutableList.Builder argumentsBuilder = ImmutableList.builder(); for (Expression argument : aggregation.getArguments()) { - if (distinctSymbol.toSymbolReference().equals(argument)) { - argumentsBuilder.add(duplicatedDistinctSymbol.toSymbolReference()); + if (new SymbolReference(distinctVariable.getName()).equals(argument)) { + argumentsBuilder.add(new SymbolReference(duplicatedDistinctVariable.getName())); } else { argumentsBuilder.add(argument); @@ -443,7 +446,7 @@ private AggregationNode createNonDistinctAggregation( else { arguments = entry.getValue().getArguments(); } - aggregations.put(newSymbol, new Aggregation( + aggregations.put(newVariable, new Aggregation( entry.getValue().getFunctionHandle(), arguments, Optional.empty(), @@ -459,7 +462,7 @@ private AggregationNode createNonDistinctAggregation( singleGroupingSet(ImmutableList.copyOf(groupByKeys)), ImmutableList.of(), SINGLE, - originalNode.getHashSymbol(), + originalNode.getHashVariable(), Optional.empty()); } @@ -475,75 +478,77 @@ private static IfExpression createIfExpression(Expression left, Expression right private static class AggregateInfo { - private final List groupBySymbols; - private final Symbol mask; - private final Map aggregations; - - // Filled on the way back, these are the symbols corresponding to their distinct or non-distinct original symbols - private Map newNonDistinctAggregateSymbols; - private Symbol newDistinctAggregateSymbol; + private final List groupByVariables; + private final VariableReferenceExpression mask; + private final Map aggregations; + private final TypeProvider types; + + // Filled on the way back, these are the variables corresponding to their distinct or non-distinct original variables + private Map newNonDistinctAggregateVariables; + private VariableReferenceExpression newDistinctAggregateVariable; private boolean foundMarkDistinct; - public AggregateInfo(List groupBySymbols, Symbol mask, Map aggregations) + public AggregateInfo(List groupByVariables, VariableReferenceExpression mask, Map aggregations, TypeProvider types) { - this.groupBySymbols = ImmutableList.copyOf(groupBySymbols); - + this.groupByVariables = ImmutableList.copyOf(groupByVariables); this.mask = mask; - this.aggregations = ImmutableMap.copyOf(aggregations); + this.types = types; } - public List getOriginalNonDistinctAggregateArgs() + public List getOriginalNonDistinctAggregateArgs() { return aggregations.values().stream() .filter(aggregation -> !aggregation.getMask().isPresent()) .flatMap(aggregation -> aggregation.getArguments().stream()) .distinct() .map(Symbol::from) + .map(symbol -> new VariableReferenceExpression(symbol.getName(), types.get(symbol))) .collect(Collectors.toList()); } - public List getOriginalDistinctAggregateArgs() + public List getOriginalDistinctAggregateArgs() { return aggregations.values().stream() .filter(aggregation -> aggregation.getMask().isPresent()) .flatMap(aggregation -> aggregation.getArguments().stream()) .distinct() .map(Symbol::from) + .map(symbol -> new VariableReferenceExpression(symbol.getName(), types.get(symbol))) .collect(Collectors.toList()); } - public Symbol getNewDistinctAggregateSymbol() + public VariableReferenceExpression getNewDistinctAggregateVariable() { - return newDistinctAggregateSymbol; + return newDistinctAggregateVariable; } - public void setNewDistinctAggregateSymbol(Symbol newDistinctAggregateSymbol) + public void setNewDistinctAggregateSymbol(VariableReferenceExpression newDistinctAggregateVariable) { - this.newDistinctAggregateSymbol = newDistinctAggregateSymbol; + this.newDistinctAggregateVariable = newDistinctAggregateVariable; } - public Map getNewNonDistinctAggregateSymbols() + public Map getNewNonDistinctAggregateVariables() { - return newNonDistinctAggregateSymbols; + return newNonDistinctAggregateVariables; } - public void setNewNonDistinctAggregateSymbols(Map newNonDistinctAggregateSymbols) + public void setNewNonDistinctAggregateSymbols(Map newNonDistinctAggregateVariables) { - this.newNonDistinctAggregateSymbols = newNonDistinctAggregateSymbols; + this.newNonDistinctAggregateVariables = newNonDistinctAggregateVariables; } - public Symbol getMask() + public VariableReferenceExpression getMask() { return mask; } - public List getGroupBySymbols() + public List getGroupByVariables() { - return groupBySymbols; + return groupByVariables; } - public Map getAggregations() + public Map getAggregations() { return aggregations; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java index 853be1e79d35e..2f97fecb7a222 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java @@ -15,8 +15,10 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -34,7 +36,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; import com.google.common.collect.Multimap; import com.google.common.collect.Sets; @@ -50,21 +51,22 @@ import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.EQUAL; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.util.Objects.requireNonNull; public class PlanNodeDecorrelator { private final PlanNodeIdAllocator idAllocator; + private final SymbolAllocator symbolAllocator; private final Lookup lookup; - public PlanNodeDecorrelator(PlanNodeIdAllocator idAllocator, Lookup lookup) + public PlanNodeDecorrelator(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Lookup lookup) { this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.lookup = requireNonNull(lookup, "lookup is null"); } - public Optional decorrelateFilters(PlanNode node, List correlation) + public Optional decorrelateFilters(PlanNode node, List correlation) { // TODO: when correlations list empty this should return immediately. However this isn't correct // right now, because for nested subqueries correlation list is empty while there might exists usages @@ -80,9 +82,9 @@ public Optional decorrelateFilters(PlanNode node, List private class DecorrelatingVisitor extends InternalPlanVisitor, Void> { - final List correlation; + final List correlation; - DecorrelatingVisitor(List correlation) + DecorrelatingVisitor(List correlation) { this.correlation = requireNonNull(correlation, "correlation is null"); } @@ -129,16 +131,16 @@ public Optional visitFilter(FilterNode node, Void context) childDecorrelationResult.node, castToRowExpression(ExpressionUtils.combineConjuncts(uncorrelatedPredicates))); - Set symbolsToPropagate = Sets.difference(SymbolsExtractor.extractUnique(correlatedPredicates), ImmutableSet.copyOf(correlation)); + Set variablesToPropagate = Sets.difference(SymbolsExtractor.extractUniqueVariable(correlatedPredicates, symbolAllocator.getTypes()), ImmutableSet.copyOf(correlation)); return Optional.of(new DecorrelationResult( newFilterNode, - Sets.union(childDecorrelationResult.symbolsToPropagate, symbolsToPropagate), + Sets.union(childDecorrelationResult.variablesToPropagate, variablesToPropagate), ImmutableList.builder() .addAll(childDecorrelationResult.correlatedPredicates) .addAll(correlatedPredicates) .build(), - ImmutableMultimap.builder() - .putAll(childDecorrelationResult.correlatedSymbolsMapping) + ImmutableMultimap.builder() + .putAll(childDecorrelationResult.correlatedVariablesMapping) .putAll(extractCorrelatedSymbolsMapping(correlatedPredicates)) .build(), childDecorrelationResult.atMostSingleRow)); @@ -161,10 +163,11 @@ public Optional visitLimit(LimitNode node, Void context) return Optional.empty(); } - Set constantSymbols = childDecorrelationResult.getConstantSymbols(); + Set constantVariables = childDecorrelationResult.getConstantVariables(); PlanNode decorrelatedChildNode = childDecorrelationResult.node; - if (constantSymbols.isEmpty() || !constantSymbols.containsAll(decorrelatedChildNode.getOutputSymbols())) { + if (constantVariables.isEmpty() || + !constantVariables.containsAll(decorrelatedChildNode.getOutputVariables())) { return Optional.empty(); } @@ -173,7 +176,7 @@ public Optional visitLimit(LimitNode node, Void context) idAllocator.getNextId(), decorrelatedChildNode, ImmutableMap.of(), - singleGroupingSet(decorrelatedChildNode.getOutputSymbols()), + singleGroupingSet(decorrelatedChildNode.getOutputVariables()), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), @@ -181,9 +184,9 @@ public Optional visitLimit(LimitNode node, Void context) return Optional.of(new DecorrelationResult( aggregationNode, - childDecorrelationResult.symbolsToPropagate, + childDecorrelationResult.variablesToPropagate, childDecorrelationResult.correlatedPredicates, - childDecorrelationResult.correlatedSymbolsMapping, + childDecorrelationResult.correlatedVariablesMapping, true)); } @@ -207,17 +210,17 @@ public Optional visitAggregation(AggregationNode node, Void } DecorrelationResult childDecorrelationResult = childDecorrelationResultOptional.get(); - Set constantSymbols = childDecorrelationResult.getConstantSymbols(); + Set constantVariables = childDecorrelationResult.getConstantVariables(); AggregationNode decorrelatedAggregation = childDecorrelationResult.getCorrelatedSymbolMapper() .map(node, childDecorrelationResult.node); - Set groupingKeys = ImmutableSet.copyOf(node.getGroupingKeys()); - List symbolsToAdd = childDecorrelationResult.symbolsToPropagate.stream() - .filter(symbol -> !groupingKeys.contains(symbol)) + Set groupingKeys = ImmutableSet.copyOf(node.getGroupingKeys()); + List variablesToAdd = childDecorrelationResult.variablesToPropagate.stream() + .filter(variable -> !groupingKeys.contains(variable)) .collect(toImmutableList()); - if (!constantSymbols.containsAll(symbolsToAdd)) { + if (!constantVariables.containsAll(variablesToAdd)) { return Optional.empty(); } @@ -225,23 +228,23 @@ public Optional visitAggregation(AggregationNode node, Void decorrelatedAggregation.getId(), decorrelatedAggregation.getSource(), decorrelatedAggregation.getAggregations(), - AggregationNode.singleGroupingSet(ImmutableList.builder() + AggregationNode.singleGroupingSet(ImmutableList.builder() .addAll(node.getGroupingKeys()) - .addAll(symbolsToAdd) + .addAll(variablesToAdd) .build()), ImmutableList.of(), decorrelatedAggregation.getStep(), - decorrelatedAggregation.getHashSymbol(), - decorrelatedAggregation.getGroupIdSymbol()); + decorrelatedAggregation.getHashVariable(), + decorrelatedAggregation.getGroupIdVariable()); boolean atMostSingleRow = newAggregation.getGroupingSetCount() == 1 - && constantSymbols.containsAll(newAggregation.getGroupingKeys()); + && constantVariables.containsAll(newAggregation.getGroupingKeys()); return Optional.of(new DecorrelationResult( newAggregation, - childDecorrelationResult.symbolsToPropagate, + childDecorrelationResult.variablesToPropagate, childDecorrelationResult.correlatedPredicates, - childDecorrelationResult.correlatedSymbolsMapping, + childDecorrelationResult.correlatedVariablesMapping, atMostSingleRow)); } @@ -254,28 +257,28 @@ public Optional visitProject(ProjectNode node, Void context } DecorrelationResult childDecorrelationResult = childDecorrelationResultOptional.get(); - Set nodeOutputSymbols = ImmutableSet.copyOf(node.getOutputSymbols()); - List symbolsToAdd = childDecorrelationResult.symbolsToPropagate.stream() - .filter(symbol -> !nodeOutputSymbols.contains(symbol)) + Set nodeOutputVariables = ImmutableSet.copyOf(node.getOutputVariables()); + List variablesToAdd = childDecorrelationResult.variablesToPropagate.stream() + .filter(variable -> !nodeOutputVariables.contains(variable)) .collect(toImmutableList()); Assignments assignments = Assignments.builder() .putAll(node.getAssignments()) - .putIdentities(symbolsToAdd) + .putIdentities(variablesToAdd) .build(); return Optional.of(new DecorrelationResult( new ProjectNode(idAllocator.getNextId(), childDecorrelationResult.node, assignments), - childDecorrelationResult.symbolsToPropagate, + childDecorrelationResult.variablesToPropagate, childDecorrelationResult.correlatedPredicates, - childDecorrelationResult.correlatedSymbolsMapping, + childDecorrelationResult.correlatedVariablesMapping, childDecorrelationResult.atMostSingleRow)); } - private Multimap extractCorrelatedSymbolsMapping(List correlatedConjuncts) + private Multimap extractCorrelatedSymbolsMapping(List correlatedConjuncts) { // TODO: handle coercions and non-direct column references - ImmutableMultimap.Builder mapping = ImmutableMultimap.builder(); + ImmutableMultimap.Builder mapping = ImmutableMultimap.builder(); for (Expression conjunct : correlatedConjuncts) { if (!(conjunct instanceof ComparisonExpression)) { continue; @@ -288,8 +291,8 @@ private Multimap extractCorrelatedSymbolsMapping(List extractCorrelatedSymbolsMapping(List symbolsToPropagate; + final Set variablesToPropagate; final List correlatedPredicates; // mapping from correlated symbols to their uncorrelated equivalence - final Multimap correlatedSymbolsMapping; + final Multimap correlatedVariablesMapping; // If a subquery has at most single row for any correlation values? final boolean atMostSingleRow; - DecorrelationResult(PlanNode node, Set symbolsToPropagate, List correlatedPredicates, Multimap correlatedSymbolsMapping, boolean atMostSingleRow) + DecorrelationResult( + PlanNode node, + Set variablesToPropagate, + List correlatedPredicates, + Multimap correlatedVariablesMapping, + boolean atMostSingleRow) { this.node = node; - this.symbolsToPropagate = symbolsToPropagate; + this.variablesToPropagate = variablesToPropagate; this.correlatedPredicates = correlatedPredicates; this.atMostSingleRow = atMostSingleRow; - this.correlatedSymbolsMapping = correlatedSymbolsMapping; - checkState(symbolsToPropagate.containsAll(correlatedSymbolsMapping.values()), "Expected symbols to propagate to contain all constant symbols"); + this.correlatedVariablesMapping = correlatedVariablesMapping; + checkState(variablesToPropagate.containsAll(correlatedVariablesMapping.values()), "Expected symbols to propagate to contain all constant symbols"); } SymbolMapper getCorrelatedSymbolMapper() { - return new SymbolMapper(correlatedSymbolsMapping.asMap().entrySet().stream() - .collect(toImmutableMap(Map.Entry::getKey, symbols -> Iterables.getLast(symbols.getValue())))); + SymbolMapper.Builder builder = SymbolMapper.builder(); + correlatedVariablesMapping.forEach(builder::put); + return builder.build(); } /** * @return constant symbols from a perspective of a subquery */ - Set getConstantSymbols() + Set getConstantVariables() { - return ImmutableSet.copyOf(correlatedSymbolsMapping.values()); + return ImmutableSet.copyOf(correlatedVariablesMapping.values()); } } private Optional decorrelatedNode( List correlatedPredicates, PlanNode node, - List correlation) + List correlation) { if (containsCorrelation(node, correlation)) { // node is still correlated ; / @@ -357,9 +366,9 @@ private Optional decorrelatedNode( return Optional.of(new DecorrelatedNode(correlatedPredicates, node)); } - private boolean containsCorrelation(PlanNode node, List correlation) + private boolean containsCorrelation(PlanNode node, List correlation) { - return Sets.union(SymbolsExtractor.extractUnique(node, lookup), SymbolsExtractor.extractOutputSymbols(node, lookup)).stream().anyMatch(correlation::contains); + return Sets.union(SymbolsExtractor.extractUniqueVariable(node, lookup, symbolAllocator.getTypes()), SymbolsExtractor.extractOutputVariables(node, lookup, symbolAllocator.getTypes())).stream().anyMatch(correlation::contains); } public static class DecorrelatedNode diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java index 06794275afad0..b56aa3092e3fd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java @@ -17,6 +17,7 @@ import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.EffectivePredicateExtractor; @@ -73,7 +74,6 @@ import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -83,7 +83,7 @@ import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static com.facebook.presto.sql.planner.EqualityInference.createEqualityInference; import static com.facebook.presto.sql.planner.ExpressionDeterminismEvaluator.isDeterministic; -import static com.facebook.presto.sql.planner.ExpressionSymbolInliner.inlineSymbols; +import static com.facebook.presto.sql.planner.ExpressionVariableInliner.inlineVariables; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL; @@ -98,9 +98,11 @@ import static com.google.common.base.Predicates.in; import static com.google.common.base.Predicates.not; import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.filter; import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; public class PredicatePushDown implements PlanOptimizer @@ -183,14 +185,14 @@ public PlanNode visitExchange(ExchangeNode node, RewriteContext cont boolean modified = false; ImmutableList.Builder builder = ImmutableList.builder(); for (int i = 0; i < node.getSources().size(); i++) { - Map outputsToInputs = new HashMap<>(); + Map outputsToInputs = new HashMap<>(); for (int index = 0; index < node.getInputs().get(i).size(); index++) { outputsToInputs.put( - node.getOutputSymbols().get(index), - node.getInputs().get(i).get(index).toSymbolReference()); + node.getOutputVariables().get(index), + new SymbolReference(node.getInputs().get(i).get(index).getName())); } - Expression sourcePredicate = inlineSymbols(outputsToInputs, context.get()); + Expression sourcePredicate = inlineVariables(outputsToInputs, context.get(), types); PlanNode source = node.getSources().get(i); PlanNode rewrittenSource = context.rewrite(source, sourcePredicate); if (rewrittenSource != source) { @@ -216,8 +218,6 @@ public PlanNode visitExchange(ExchangeNode node, RewriteContext cont @Override public PlanNode visitWindow(WindowNode node, RewriteContext context) { - List partitionSymbols = node.getPartitionBy(); - // TODO: This could be broader. We can push down conjucts if they are constant for all rows in a window partition. // The simplest way to guarantee this is if the conjucts are deterministic functions of the partitioning symbols. // This can leave out cases where they're both functions of some set of common expressions and the partitioning @@ -225,8 +225,8 @@ public PlanNode visitWindow(WindowNode node, RewriteContext context) // pre-projected symbols. Predicate isSupported = conjunct -> ExpressionDeterminismEvaluator.isDeterministic(conjunct) && - SymbolsExtractor.extractUnique(conjunct).stream() - .allMatch(partitionSymbols::contains); + SymbolsExtractor.extractUniqueVariable(conjunct, types).stream() + .allMatch(node.getPartitionBy()::contains); Map> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(isSupported)); @@ -242,13 +242,12 @@ public PlanNode visitWindow(WindowNode node, RewriteContext context) @Override public PlanNode visitProject(ProjectNode node, RewriteContext context) { - Set deterministicSymbols = node.getAssignments().entrySet().stream() + Set deterministicVariables = node.getAssignments().entrySet().stream() .filter(entry -> ExpressionDeterminismEvaluator.isDeterministic(entry.getValue())) .map(Map.Entry::getKey) .collect(Collectors.toSet()); - Predicate deterministic = conjunct -> SymbolsExtractor.extractUnique(conjunct).stream() - .allMatch(deterministicSymbols::contains); + Predicate deterministic = conjunct -> deterministicVariables.containsAll(SymbolsExtractor.extractUniqueVariable(conjunct, types)); Map> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(deterministic)); @@ -262,7 +261,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext contex .collect(Collectors.partitioningBy(expression -> isInliningCandidate(expression, node))); List inlinedDeterministicConjuncts = inlineConjuncts.get(true).stream() - .map(entry -> inlineSymbols(node.getAssignments().getMap(), entry)) + .map(entry -> inlineVariables(node.getAssignments().getMap(), entry, types)) .collect(Collectors.toList()); PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(inlinedDeterministicConjuncts)); @@ -290,10 +289,10 @@ private boolean isInliningCandidate(Expression expression, ProjectNode node) // 1. references to simple constants // 2. references to complex expressions that appear only once // which come from the node, as opposed to an enclosing scope. - Set childOutputSet = ImmutableSet.copyOf(node.getOutputSymbols()); - Map dependencies = SymbolsExtractor.extractAll(expression).stream() + Set childOutputSet = ImmutableSet.copyOf(node.getOutputVariables()); + Map dependencies = SymbolsExtractor.extractAllVariable(expression, types).stream() .filter(childOutputSet::contains) - .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); + .collect(Collectors.groupingBy(identity(), Collectors.counting())); return dependencies.entrySet().stream() .allMatch(entry -> entry.getValue() == 1 || node.getAssignments().get(entry.getKey()) instanceof Literal); @@ -302,17 +301,17 @@ private boolean isInliningCandidate(Expression expression, ProjectNode node) @Override public PlanNode visitGroupId(GroupIdNode node, RewriteContext context) { - Map commonGroupingSymbolMapping = node.getGroupingColumns().entrySet().stream() + Map commonGroupingVariableMapping = node.getGroupingColumns().entrySet().stream() .filter(entry -> node.getCommonGroupingColumns().contains(entry.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + .collect(Collectors.toMap(Map.Entry::getKey, entry -> new SymbolReference(entry.getValue().getName()))); - Predicate pushdownEligiblePredicate = conjunct -> SymbolsExtractor.extractUnique(conjunct).stream() - .allMatch(commonGroupingSymbolMapping.keySet()::contains); + Predicate pushdownEligiblePredicate = conjunct -> SymbolsExtractor.extractUniqueVariable(conjunct, types).stream() + .allMatch(commonGroupingVariableMapping.keySet()::contains); Map> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(pushdownEligiblePredicate)); // Push down conjuncts from the inherited predicate that apply to common grouping symbols - PlanNode rewrittenNode = context.defaultRewrite(node, inlineSymbols(commonGroupingSymbolMapping, combineConjuncts(conjuncts.get(true)))); + PlanNode rewrittenNode = context.defaultRewrite(node, inlineVariables(commonGroupingVariableMapping, combineConjuncts(conjuncts.get(true)), types)); // All other conjuncts, if any, will be in the filter node. if (!conjuncts.get(false).isEmpty()) { @@ -325,9 +324,9 @@ public PlanNode visitGroupId(GroupIdNode node, RewriteContext contex @Override public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext context) { - Set pushDownableSymbols = ImmutableSet.copyOf(node.getDistinctSymbols()); + Set pushDownableVariables = ImmutableSet.copyOf(node.getDistinctVariables()); Map> conjuncts = extractConjuncts(context.get()).stream() - .collect(Collectors.partitioningBy(conjunct -> SymbolsExtractor.extractUnique(conjunct).stream().allMatch(pushDownableSymbols::contains))); + .collect(Collectors.partitioningBy(conjunct -> pushDownableVariables.containsAll(SymbolsExtractor.extractUniqueVariable(conjunct, types)))); PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(conjuncts.get(true))); @@ -349,7 +348,7 @@ public PlanNode visitUnion(UnionNode node, RewriteContext context) boolean modified = false; ImmutableList.Builder builder = ImmutableList.builder(); for (int i = 0; i < node.getSources().size(); i++) { - Expression sourcePredicate = inlineSymbols(Maps.transformValues(node.sourceSymbolMap(i), Symbol::toSymbolReference), context.get()); + Expression sourcePredicate = inlineVariables(Maps.transformValues(node.sourceVariableMap(i), variable -> new SymbolReference(variable.getName())), context.get(), types); PlanNode source = node.getSources().get(i); PlanNode rewrittenSource = context.rewrite(source, sourcePredicate); if (rewrittenSource != source) { @@ -359,7 +358,7 @@ public PlanNode visitUnion(UnionNode node, RewriteContext context) } if (modified) { - return new UnionNode(node.getId(), builder.build(), node.getSymbolMapping(), node.getOutputSymbols()); + return new UnionNode(node.getId(), builder.build(), node.getVariableMapping()); } return node; @@ -391,8 +390,8 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) // See if we can rewrite outer joins in terms of a plain inner join node = tryNormalizeToOuterToInnerJoin(node, inheritedPredicate); - Expression leftEffectivePredicate = effectivePredicateExtractor.extract(node.getLeft()); - Expression rightEffectivePredicate = effectivePredicateExtractor.extract(node.getRight()); + Expression leftEffectivePredicate = effectivePredicateExtractor.extract(node.getLeft(), types); + Expression rightEffectivePredicate = effectivePredicateExtractor.extract(node.getRight(), types); Expression joinPredicate = extractJoinPredicate(node); Expression leftPredicate; @@ -406,7 +405,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) leftEffectivePredicate, rightEffectivePredicate, joinPredicate, - node.getLeft().getOutputSymbols()); + node.getLeft().getOutputVariables()); leftPredicate = innerJoinPushDownResult.getLeftPredicate(); rightPredicate = innerJoinPushDownResult.getRightPredicate(); postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate(); @@ -417,7 +416,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) leftEffectivePredicate, rightEffectivePredicate, joinPredicate, - node.getLeft().getOutputSymbols()); + node.getLeft().getOutputVariables()); leftPredicate = leftOuterJoinPushDownResult.getOuterJoinPredicate(); rightPredicate = leftOuterJoinPushDownResult.getInnerJoinPredicate(); postJoinPredicate = leftOuterJoinPushDownResult.getPostJoinPredicate(); @@ -428,7 +427,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) rightEffectivePredicate, leftEffectivePredicate, joinPredicate, - node.getRight().getOutputSymbols()); + node.getRight().getOutputVariables()); leftPredicate = rightOuterJoinPushDownResult.getInnerJoinPredicate(); rightPredicate = rightOuterJoinPushDownResult.getOuterJoinPredicate(); postJoinPredicate = rightOuterJoinPushDownResult.getPostJoinPredicate(); @@ -458,36 +457,36 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) // Create identity projections for all existing symbols Assignments.Builder leftProjections = Assignments.builder(); leftProjections.putAll(node.getLeft() - .getOutputSymbols().stream() - .collect(Collectors.toMap(key -> key, Symbol::toSymbolReference))); + .getOutputVariables().stream() + .collect(Collectors.toMap(identity(), variable -> new SymbolReference(variable.getName())))); Assignments.Builder rightProjections = Assignments.builder(); rightProjections.putAll(node.getRight() - .getOutputSymbols().stream() - .collect(Collectors.toMap(key -> key, Symbol::toSymbolReference))); + .getOutputVariables().stream() + .collect(Collectors.toMap(identity(), variable -> new SymbolReference(variable.getName())))); // Create new projections for the new join clauses List equiJoinClauses = new ArrayList<>(); ImmutableList.Builder joinFilterBuilder = ImmutableList.builder(); for (Expression conjunct : extractConjuncts(newJoinPredicate)) { - if (joinEqualityExpression(node.getLeft().getOutputSymbols()).test(conjunct)) { + if (joinEqualityExpression(node.getLeft().getOutputVariables()).test(conjunct)) { ComparisonExpression equality = (ComparisonExpression) conjunct; - boolean alignedComparison = Iterables.all(SymbolsExtractor.extractUnique(equality.getLeft()), in(node.getLeft().getOutputSymbols())); + boolean alignedComparison = Iterables.all(SymbolsExtractor.extractUniqueVariable(equality.getLeft(), types), in(node.getLeft().getOutputVariables())); Expression leftExpression = (alignedComparison) ? equality.getLeft() : equality.getRight(); Expression rightExpression = (alignedComparison) ? equality.getRight() : equality.getLeft(); - Symbol leftSymbol = symbolForExpression(leftExpression); - if (!node.getLeft().getOutputSymbols().contains(leftSymbol)) { - leftProjections.put(leftSymbol, leftExpression); + VariableReferenceExpression leftVariable = variableForExpression(leftExpression); + if (!node.getLeft().getOutputVariables().contains(leftVariable)) { + leftProjections.put(leftVariable, leftExpression); } - Symbol rightSymbol = symbolForExpression(rightExpression); - if (!node.getRight().getOutputSymbols().contains(rightSymbol)) { - rightProjections.put(rightSymbol, rightExpression); + VariableReferenceExpression rightVariable = variableForExpression(rightExpression); + if (!node.getRight().getOutputVariables().contains(rightVariable)) { + rightProjections.put(rightVariable, rightExpression); } - equiJoinClauses.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol)); + equiJoinClauses.add(new JoinNode.EquiJoinClause(leftVariable, rightVariable)); } else { joinFilterBuilder.add(conjunct); @@ -537,13 +536,13 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) leftSource, rightSource, equiJoinClauses, - ImmutableList.builder() - .addAll(leftSource.getOutputSymbols()) - .addAll(rightSource.getOutputSymbols()) + ImmutableList.builder() + .addAll(leftSource.getOutputVariables()) + .addAll(rightSource.getOutputVariables()) .build(), newJoinFilter.map(OriginalExpressionUtils::castToRowExpression), - node.getLeftHashSymbol(), - node.getRightHashSymbol(), + node.getLeftHashVariable(), + node.getRightHashVariable(), distributionType); } @@ -551,8 +550,8 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) output = new FilterNode(idAllocator.getNextId(), output, castToRowExpression(postJoinPredicate)); } - if (!node.getOutputSymbols().equals(output.getOutputSymbols())) { - output = new ProjectNode(idAllocator.getNextId(), output, Assignments.identity(node.getOutputSymbols())); + if (!node.getOutputVariables().equals(output.getOutputVariables())) { + output = new ProjectNode(idAllocator.getNextId(), output, Assignments.identity(node.getOutputVariables())); } return output; @@ -564,12 +563,21 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext key, Symbol::toSymbolReference))); + .getOutputVariables().stream() + .collect(Collectors.toMap(identity(), variable -> new SymbolReference(variable.getName())))); Assignments.Builder rightProjections = Assignments.builder(); rightProjections.putAll(node.getRight() - .getOutputSymbols().stream() - .collect(Collectors.toMap(key -> key, Symbol::toSymbolReference))); + .getOutputVariables().stream() + .collect(Collectors.toMap(identity(), variable -> new SymbolReference(variable.getName())))); leftSource = new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build()); rightSource = new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build()); @@ -635,10 +643,10 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext outerSymbols) + private OuterJoinPushDownResult processLimitedOuterJoin(Expression inheritedPredicate, Expression outerEffectivePredicate, Expression innerEffectivePredicate, Expression joinPredicate, Collection outerVariables) { - checkArgument(Iterables.all(SymbolsExtractor.extractUnique(outerEffectivePredicate), in(outerSymbols)), "outerEffectivePredicate must only contain symbols from outerSymbols"); - checkArgument(Iterables.all(SymbolsExtractor.extractUnique(innerEffectivePredicate), not(in(outerSymbols))), "innerEffectivePredicate must not contain symbols from outerSymbols"); + checkArgument(Iterables.all(SymbolsExtractor.extractUniqueVariable(outerEffectivePredicate, types), in(outerVariables)), "outerEffectivePredicate must only contain variables from outerVariables"); + checkArgument(Iterables.all(SymbolsExtractor.extractUniqueVariable(innerEffectivePredicate, types), not(in(outerVariables))), "innerEffectivePredicate must not contain variables from outerVariables"); ImmutableList.Builder outerPushdownConjuncts = ImmutableList.builder(); ImmutableList.Builder innerPushdownConjuncts = ImmutableList.builder(); @@ -681,18 +689,18 @@ private static OuterJoinPushDownResult processLimitedOuterJoin(Expression inheri EqualityInference inheritedInference = createEqualityInference(inheritedPredicate); EqualityInference outerInference = createEqualityInference(inheritedPredicate, outerEffectivePredicate); - EqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(in(outerSymbols)); + EqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(in(outerVariables), types); Expression outerOnlyInheritedEqualities = combineConjuncts(equalityPartition.getScopeEqualities()); EqualityInference potentialNullSymbolInference = createEqualityInference(outerOnlyInheritedEqualities, outerEffectivePredicate, innerEffectivePredicate, joinPredicate); // See if we can push inherited predicates down for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) { - Expression outerRewritten = outerInference.rewriteExpression(conjunct, in(outerSymbols)); + Expression outerRewritten = outerInference.rewriteExpression(conjunct, in(outerVariables), types); if (outerRewritten != null) { outerPushdownConjuncts.add(outerRewritten); // A conjunct can only be pushed down into an inner side if it can be rewritten in terms of the outer side - Expression innerRewritten = potentialNullSymbolInference.rewriteExpression(outerRewritten, not(in(outerSymbols))); + Expression innerRewritten = potentialNullSymbolInference.rewriteExpression(outerRewritten, not(in(outerVariables)), types); if (innerRewritten != null) { innerPushdownConjuncts.add(innerRewritten); } @@ -708,7 +716,7 @@ private static OuterJoinPushDownResult processLimitedOuterJoin(Expression inheri // See if we can push down any outer effective predicates to the inner side for (Expression conjunct : EqualityInference.nonInferrableConjuncts(outerEffectivePredicate)) { - Expression rewritten = potentialNullSymbolInference.rewriteExpression(conjunct, not(in(outerSymbols))); + Expression rewritten = potentialNullSymbolInference.rewriteExpression(conjunct, not(in(outerVariables)), types); if (rewritten != null) { innerPushdownConjuncts.add(rewritten); } @@ -716,7 +724,7 @@ private static OuterJoinPushDownResult processLimitedOuterJoin(Expression inheri // See if we can push down join predicates to the inner side for (Expression conjunct : EqualityInference.nonInferrableConjuncts(joinPredicate)) { - Expression innerRewritten = potentialNullSymbolInference.rewriteExpression(conjunct, not(in(outerSymbols))); + Expression innerRewritten = potentialNullSymbolInference.rewriteExpression(conjunct, not(in(outerVariables)), types); if (innerRewritten != null) { innerPushdownConjuncts.add(innerRewritten); } @@ -729,10 +737,10 @@ private static OuterJoinPushDownResult processLimitedOuterJoin(Expression inheri // SELECT * FROM nation LEFT OUTER JOIN region ON nation.regionkey = region.regionkey and nation.name = region.name WHERE nation.name = 'blah' EqualityInference potentialNullSymbolInferenceWithoutInnerInferred = createEqualityInference(outerOnlyInheritedEqualities, outerEffectivePredicate, joinPredicate); - innerPushdownConjuncts.addAll(potentialNullSymbolInferenceWithoutInnerInferred.generateEqualitiesPartitionedBy(not(in(outerSymbols))).getScopeEqualities()); + innerPushdownConjuncts.addAll(potentialNullSymbolInferenceWithoutInnerInferred.generateEqualitiesPartitionedBy(not(in(outerVariables)), types).getScopeEqualities()); // TODO: we can further improve simplifying the equalities by considering other relationships from the outer side - EqualityInference.EqualityPartition joinEqualityPartition = createEqualityInference(joinPredicate).generateEqualitiesPartitionedBy(not(in(outerSymbols))); + EqualityInference.EqualityPartition joinEqualityPartition = createEqualityInference(joinPredicate).generateEqualitiesPartitionedBy(not(in(outerVariables)), types); innerPushdownConjuncts.addAll(joinEqualityPartition.getScopeEqualities()); joinConjuncts.addAll(joinEqualityPartition.getScopeComplementEqualities()) .addAll(joinEqualityPartition.getScopeStraddlingEqualities()); @@ -779,10 +787,10 @@ private Expression getPostJoinPredicate() } } - private static InnerJoinPushDownResult processInnerJoin(Expression inheritedPredicate, Expression leftEffectivePredicate, Expression rightEffectivePredicate, Expression joinPredicate, Collection leftSymbols) + private InnerJoinPushDownResult processInnerJoin(Expression inheritedPredicate, Expression leftEffectivePredicate, Expression rightEffectivePredicate, Expression joinPredicate, Collection leftVariables) { - checkArgument(Iterables.all(SymbolsExtractor.extractUnique(leftEffectivePredicate), in(leftSymbols)), "leftEffectivePredicate must only contain symbols from leftSymbols"); - checkArgument(Iterables.all(SymbolsExtractor.extractUnique(rightEffectivePredicate), not(in(leftSymbols))), "rightEffectivePredicate must not contain symbols from leftSymbols"); + checkArgument(Iterables.all(SymbolsExtractor.extractUniqueVariable(leftEffectivePredicate, types), in(leftVariables)), "leftEffectivePredicate must only contain variables from leftVariables"); + checkArgument(Iterables.all(SymbolsExtractor.extractUniqueVariable(rightEffectivePredicate, types), not(in(leftVariables))), "rightEffectivePredicate must not contain variables from leftVariables"); ImmutableList.Builder leftPushDownConjuncts = ImmutableList.builder(); ImmutableList.Builder rightPushDownConjuncts = ImmutableList.builder(); @@ -805,12 +813,12 @@ private static InnerJoinPushDownResult processInnerJoin(Expression inheritedPred // Sort through conjuncts in inheritedPredicate that were not used for inference for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) { - Expression leftRewrittenConjunct = allInference.rewriteExpression(conjunct, in(leftSymbols)); + Expression leftRewrittenConjunct = allInference.rewriteExpression(conjunct, in(leftVariables), types); if (leftRewrittenConjunct != null) { leftPushDownConjuncts.add(leftRewrittenConjunct); } - Expression rightRewrittenConjunct = allInference.rewriteExpression(conjunct, not(in(leftSymbols))); + Expression rightRewrittenConjunct = allInference.rewriteExpression(conjunct, not(in(leftVariables)), types); if (rightRewrittenConjunct != null) { rightPushDownConjuncts.add(rightRewrittenConjunct); } @@ -823,7 +831,7 @@ private static InnerJoinPushDownResult processInnerJoin(Expression inheritedPred // See if we can push the right effective predicate to the left side for (Expression conjunct : EqualityInference.nonInferrableConjuncts(rightEffectivePredicate)) { - Expression rewritten = allInference.rewriteExpression(conjunct, in(leftSymbols)); + Expression rewritten = allInference.rewriteExpression(conjunct, in(leftVariables), types); if (rewritten != null) { leftPushDownConjuncts.add(rewritten); } @@ -831,7 +839,7 @@ private static InnerJoinPushDownResult processInnerJoin(Expression inheritedPred // See if we can push the left effective predicate to the right side for (Expression conjunct : EqualityInference.nonInferrableConjuncts(leftEffectivePredicate)) { - Expression rewritten = allInference.rewriteExpression(conjunct, not(in(leftSymbols))); + Expression rewritten = allInference.rewriteExpression(conjunct, not(in(leftVariables)), types); if (rewritten != null) { rightPushDownConjuncts.add(rewritten); } @@ -839,12 +847,12 @@ private static InnerJoinPushDownResult processInnerJoin(Expression inheritedPred // See if we can push any parts of the join predicates to either side for (Expression conjunct : EqualityInference.nonInferrableConjuncts(joinPredicate)) { - Expression leftRewritten = allInference.rewriteExpression(conjunct, in(leftSymbols)); + Expression leftRewritten = allInference.rewriteExpression(conjunct, in(leftVariables), types); if (leftRewritten != null) { leftPushDownConjuncts.add(leftRewritten); } - Expression rightRewritten = allInference.rewriteExpression(conjunct, not(in(leftSymbols))); + Expression rightRewritten = allInference.rewriteExpression(conjunct, not(in(leftVariables)), types); if (rightRewritten != null) { rightPushDownConjuncts.add(rightRewritten); } @@ -855,9 +863,9 @@ private static InnerJoinPushDownResult processInnerJoin(Expression inheritedPred } // Add equalities from the inference back in - leftPushDownConjuncts.addAll(allInferenceWithoutLeftInferred.generateEqualitiesPartitionedBy(in(leftSymbols)).getScopeEqualities()); - rightPushDownConjuncts.addAll(allInferenceWithoutRightInferred.generateEqualitiesPartitionedBy(not(in(leftSymbols))).getScopeEqualities()); - joinConjuncts.addAll(allInference.generateEqualitiesPartitionedBy(in(leftSymbols)::apply).getScopeStraddlingEqualities()); // scope straddling equalities get dropped in as part of the join predicate + leftPushDownConjuncts.addAll(allInferenceWithoutLeftInferred.generateEqualitiesPartitionedBy(in(leftVariables), types).getScopeEqualities()); + rightPushDownConjuncts.addAll(allInferenceWithoutRightInferred.generateEqualitiesPartitionedBy(not(in(leftVariables)), types).getScopeEqualities()); + joinConjuncts.addAll(allInference.generateEqualitiesPartitionedBy(in(leftVariables)::apply, types).getScopeStraddlingEqualities()); // scope straddling equalities get dropped in as part of the join predicate return new InnerJoinPushDownResult(combineConjuncts(leftPushDownConjuncts.build()), combineConjuncts(rightPushDownConjuncts.build()), combineConjuncts(joinConjuncts.build()), TRUE_LITERAL); } @@ -910,7 +918,7 @@ private static Expression extractJoinPredicate(JoinNode joinNode) private Type extractType(Expression expression) { - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), /* parameters have already been replaced */WarningCollector.NOOP); + Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, expression, emptyList(), /* parameters have already been replaced */WarningCollector.NOOP); return expressionTypes.get(NodeRef.of(expression)); } @@ -923,30 +931,30 @@ private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, Expression inheri } if (node.getType() == JoinNode.Type.FULL) { - boolean canConvertToLeftJoin = canConvertOuterToInner(node.getLeft().getOutputSymbols(), inheritedPredicate); - boolean canConvertToRightJoin = canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate); + boolean canConvertToLeftJoin = canConvertOuterToInner(node.getLeft().getOutputVariables(), inheritedPredicate); + boolean canConvertToRightJoin = canConvertOuterToInner(node.getRight().getOutputVariables(), inheritedPredicate); if (!canConvertToLeftJoin && !canConvertToRightJoin) { return node; } if (canConvertToLeftJoin && canConvertToRightJoin) { - return new JoinNode(node.getId(), INNER, node.getLeft(), node.getRight(), node.getCriteria(), node.getOutputSymbols(), node.getFilter(), node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType()); + return new JoinNode(node.getId(), INNER, node.getLeft(), node.getRight(), node.getCriteria(), node.getOutputVariables(), node.getFilter(), node.getLeftHashVariable(), node.getRightHashVariable(), node.getDistributionType()); } else { return new JoinNode(node.getId(), canConvertToLeftJoin ? LEFT : RIGHT, - node.getLeft(), node.getRight(), node.getCriteria(), node.getOutputSymbols(), node.getFilter(), node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType()); + node.getLeft(), node.getRight(), node.getCriteria(), node.getOutputVariables(), node.getFilter(), node.getLeftHashVariable(), node.getRightHashVariable(), node.getDistributionType()); } } - if (node.getType() == JoinNode.Type.LEFT && !canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate) || - node.getType() == JoinNode.Type.RIGHT && !canConvertOuterToInner(node.getLeft().getOutputSymbols(), inheritedPredicate)) { + if (node.getType() == JoinNode.Type.LEFT && !canConvertOuterToInner(node.getRight().getOutputVariables(), inheritedPredicate) || + node.getType() == JoinNode.Type.RIGHT && !canConvertOuterToInner(node.getLeft().getOutputVariables(), inheritedPredicate)) { return node; } - return new JoinNode(node.getId(), JoinNode.Type.INNER, node.getLeft(), node.getRight(), node.getCriteria(), node.getOutputSymbols(), node.getFilter(), node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType()); + return new JoinNode(node.getId(), JoinNode.Type.INNER, node.getLeft(), node.getRight(), node.getCriteria(), node.getOutputVariables(), node.getFilter(), node.getLeftHashVariable(), node.getRightHashVariable(), node.getDistributionType()); } - private boolean canConvertOuterToInner(List innerSymbolsForOuterJoin, Expression inheritedPredicate) + private boolean canConvertOuterToInner(List innerVariablesForOuterJoin, Expression inheritedPredicate) { - Set innerSymbols = ImmutableSet.copyOf(innerSymbolsForOuterJoin); + Set innerSymbols = innerVariablesForOuterJoin.stream().map(VariableReferenceExpression::getName).map(Symbol::new).collect(toImmutableSet()); for (Expression conjunct : extractConjuncts(inheritedPredicate)) { if (ExpressionDeterminismEvaluator.isDeterministic(conjunct)) { // Ignore a conjunct for this test if we can not deterministically get responses from it @@ -969,7 +977,7 @@ private Expression simplifyExpression(Expression expression) session, metadata, sqlParser, - symbolAllocator.getTypes(), + types, expression, emptyList(), /* parameters have already been replaced */ WarningCollector.NOOP); @@ -991,7 +999,7 @@ private Object nullInputEvaluator(final Collection nullSymbols, Expressi session, metadata, sqlParser, - symbolAllocator.getTypes(), + types, expression, emptyList(), /* parameters have already been replaced */ WarningCollector.NOOP); @@ -999,20 +1007,20 @@ private Object nullInputEvaluator(final Collection nullSymbols, Expressi .optimize(symbol -> nullSymbols.contains(symbol) ? null : symbol.toSymbolReference()); } - private static Predicate joinEqualityExpression(final Collection leftSymbols) + private Predicate joinEqualityExpression(final Collection leftVariables) { return expression -> { // At this point in time, our join predicates need to be deterministic if (isDeterministic(expression) && expression instanceof ComparisonExpression) { ComparisonExpression comparison = (ComparisonExpression) expression; if (comparison.getOperator() == ComparisonExpression.Operator.EQUAL) { - Set symbols1 = SymbolsExtractor.extractUnique(comparison.getLeft()); - Set symbols2 = SymbolsExtractor.extractUnique(comparison.getRight()); - if (symbols1.isEmpty() || symbols2.isEmpty()) { + Set variables1 = SymbolsExtractor.extractUniqueVariable(comparison.getLeft(), types); + Set variables2 = SymbolsExtractor.extractUniqueVariable(comparison.getRight(), types); + if (variables1.isEmpty() || variables2.isEmpty()) { return false; } - return (Iterables.all(symbols1, in(leftSymbols)) && Iterables.all(symbols2, not(in(leftSymbols)))) || - (Iterables.all(symbols2, in(leftSymbols)) && Iterables.all(symbols1, not(in(leftSymbols)))); + return (Iterables.all(variables1, in(leftVariables)) && Iterables.all(variables2, not(in(leftVariables)))) || + (Iterables.all(variables2, in(leftVariables)) && Iterables.all(variables1, not(in(leftVariables)))); } } return false; @@ -1023,7 +1031,7 @@ private static Predicate joinEqualityExpression(final Collection context) { Expression inheritedPredicate = context.get(); - if (!extractConjuncts(inheritedPredicate).contains(node.getSemiJoinOutput().toSymbolReference())) { + if (!extractConjuncts(inheritedPredicate).contains(new SymbolReference(node.getSemiJoinOutput().getName()))) { return visitNonFilteringSemiJoin(node, context); } return visitFilteringSemiJoin(node, context); @@ -1042,7 +1050,7 @@ private PlanNode visitNonFilteringSemiJoin(SemiJoinNode node, RewriteContext sourceSymbols = node.getSource().getOutputSymbols(); - List filteringSourceSymbols = node.getFilteringSource().getOutputSymbols(); + List sourceVariables = node.getSource().getOutputVariables(); + List filteringSourceVariables = node.getFilteringSource().getOutputVariables(); List sourceConjuncts = new ArrayList<>(); List filteringSourceConjuncts = new ArrayList<>(); @@ -1096,7 +1104,7 @@ private PlanNode visitFilteringSemiJoin(SemiJoinNode node, RewriteContext filter // See if we can push the filtering source effective predicate to the source side for (Expression conjunct : EqualityInference.nonInferrableConjuncts(filteringSourceEffectivePredicate)) { - Expression rewritten = allInference.rewriteExpression(conjunct, in(sourceSymbols)); + Expression rewritten = allInference.rewriteExpression(conjunct, in(sourceVariables), types); if (rewritten != null) { sourceConjuncts.add(rewritten); } @@ -1127,15 +1135,15 @@ private PlanNode visitFilteringSemiJoin(SemiJoinNode node, RewriteContext pushdownConjuncts = new ArrayList<>(); List postAggregationConjuncts = new ArrayList<>(); + List groupingKeyVariables = node.getGroupingKeys(); + // Strip out non-deterministic conjuncts postAggregationConjuncts.addAll(ImmutableList.copyOf(filter(extractConjuncts(inheritedPredicate), not(ExpressionDeterminismEvaluator::isDeterministic)))); inheritedPredicate = filterDeterministicConjuncts(inheritedPredicate); // Sort non-equality predicates by those that can be pushed down and those that cannot for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) { - if (node.getGroupIdSymbol().isPresent() && SymbolsExtractor.extractUnique(conjunct).contains(node.getGroupIdSymbol().get())) { + if (node.getGroupIdVariable().isPresent() && SymbolsExtractor.extractUniqueVariable(conjunct, types).contains(node.getGroupIdVariable().get())) { // aggregation operator synthesizes outputs for group ids corresponding to the global grouping set (i.e., ()), so we // need to preserve any predicates that evaluate the group id to run after the aggregation // TODO: we should be able to infer if conditions on grouping() correspond to global grouping sets to determine whether @@ -1190,7 +1200,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext context) // Sort non-equality predicates by those that can be pushed down and those that cannot for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) { - Expression rewrittenConjunct = equalityInference.rewriteExpression(conjunct, in(node.getReplicateSymbols())); + Expression rewrittenConjunct = equalityInference.rewriteExpression(conjunct, in(node.getReplicateVariables()), types); if (rewrittenConjunct != null) { pushdownConjuncts.add(rewrittenConjunct); } @@ -1250,7 +1260,7 @@ public PlanNode visitUnnest(UnnestNode node, RewriteContext context) } // Add the equality predicates back in - EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(in(node.getReplicateSymbols())::apply); + EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(in(node.getReplicateVariables())::apply, types); pushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); postUnnestConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); postUnnestConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); @@ -1259,7 +1269,7 @@ public PlanNode visitUnnest(UnnestNode node, RewriteContext context) PlanNode output = node; if (rewrittenSource != node.getSource()) { - output = new UnnestNode(node.getId(), rewrittenSource, node.getReplicateSymbols(), node.getUnnestSymbols(), node.getOrdinalitySymbol()); + output = new UnnestNode(node.getId(), rewrittenSource, node.getReplicateVariables(), node.getUnnestVariables(), node.getOrdinalityVariable()); } if (!postUnnestConjuncts.isEmpty()) { output = new FilterNode(idAllocator.getNextId(), output, castToRowExpression(combineConjuncts(postUnnestConjuncts))); @@ -1288,8 +1298,8 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext co @Override public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext context) { - Set predicateSymbols = SymbolsExtractor.extractUnique(context.get()); - checkState(!predicateSymbols.contains(node.getIdColumn()), "UniqueId in predicate is not yet supported"); + Set predicateVariables = SymbolsExtractor.extractUniqueVariable(context.get(), types); + checkState(!predicateVariables.contains(node.getIdVariable()), "UniqueId in predicate is not yet supported"); return context.defaultRewrite(node, context.get()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PreferredProperties.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PreferredProperties.java index 7089a11370a0c..a8e96d3508eb3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PreferredProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PreferredProperties.java @@ -14,8 +14,8 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.spi.LocalProperty; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Partitioning; -import com.facebook.presto.sql.planner.Symbol; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; @@ -36,11 +36,11 @@ class PreferredProperties { private final Optional globalProperties; - private final List> localProperties; + private final List> localProperties; private PreferredProperties( Optional globalProperties, - List> localProperties) + List> localProperties) { requireNonNull(globalProperties, "globalProperties is null"); requireNonNull(localProperties, "localProperties is null"); @@ -61,14 +61,14 @@ public static PreferredProperties undistributed() .build(); } - public static PreferredProperties partitioned(Set columns) + public static PreferredProperties partitioned(Set columns) { return builder() .global(Global.distributed(PartitioningProperties.partitioned(columns))) .build(); } - public static PreferredProperties partitionedWithNullsAndAnyReplicated(Set columns) + public static PreferredProperties partitionedWithNullsAndAnyReplicated(Set columns) { return builder() .global(Global.distributed(PartitioningProperties.partitioned(columns).withNullsAndAnyReplicated(true))) @@ -96,7 +96,7 @@ public static PreferredProperties partitionedWithNullsAndAnyReplicated(Partition .build(); } - public static PreferredProperties partitionedWithLocal(Set columns, List> localProperties) + public static PreferredProperties partitionedWithLocal(Set columns, List> localProperties) { return builder() .global(Global.distributed(PartitioningProperties.partitioned(columns))) @@ -104,7 +104,7 @@ public static PreferredProperties partitionedWithLocal(Set columns, List .build(); } - public static PreferredProperties undistributedWithLocal(List> localProperties) + public static PreferredProperties undistributedWithLocal(List> localProperties) { return builder() .global(Global.undistributed()) @@ -112,7 +112,7 @@ public static PreferredProperties undistributedWithLocal(List> localProperties) + public static PreferredProperties local(List> localProperties) { return builder() .local(localProperties) @@ -124,14 +124,14 @@ public Optional getGlobalProperties() return globalProperties; } - public List> getLocalProperties() + public List> getLocalProperties() { return localProperties; } public PreferredProperties mergeWithParent(PreferredProperties parent) { - List> newLocal = ImmutableList.>builder() + List> newLocal = ImmutableList.>builder() .addAll(localProperties) .addAll(parent.getLocalProperties()) .build(); @@ -153,10 +153,10 @@ public PreferredProperties mergeWithParent(PreferredProperties parent) return builder.build(); } - public PreferredProperties translate(Function> translator) + public PreferredProperties translate(Function> translator) { Optional newGlobalProperties = globalProperties.map(global -> global.translate(translator)); - List> newLocalProperties = LocalProperties.translate(localProperties, translator); + List> newLocalProperties = LocalProperties.translate(localProperties, translator); return new PreferredProperties(newGlobalProperties, newLocalProperties); } @@ -168,7 +168,7 @@ public static Builder builder() public static class Builder { private Optional globalProperties = Optional.empty(); - private List> localProperties = ImmutableList.of(); + private List> localProperties = ImmutableList.of(); public Builder global(Global globalProperties) { @@ -188,7 +188,7 @@ public Builder global(PreferredProperties other) return this; } - public Builder local(List> localProperties) + public Builder local(List> localProperties) { this.localProperties = ImmutableList.copyOf(localProperties); return this; @@ -262,7 +262,7 @@ public Global mergeWithParent(Global parent) return new Global(distributed, Optional.of(partitioningProperties.get().mergeWithParent(parent.partitioningProperties.get()))); } - public Global translate(Function> translator) + public Global translate(Function> translator) { if (!isDistributed()) { return this; @@ -303,17 +303,17 @@ public String toString() @Immutable public static final class PartitioningProperties { - private final Set partitioningColumns; + private final Set partitioningColumns; private final Optional partitioning; // Specific partitioning requested private final boolean nullsAndAnyReplicated; - private PartitioningProperties(Set partitioningColumns, Optional partitioning, boolean nullsAndAnyReplicated) + private PartitioningProperties(Set partitioningColumns, Optional partitioning, boolean nullsAndAnyReplicated) { this.partitioningColumns = ImmutableSet.copyOf(requireNonNull(partitioningColumns, "partitioningColumns is null")); this.partitioning = requireNonNull(partitioning, "function is null"); this.nullsAndAnyReplicated = nullsAndAnyReplicated; - checkArgument(!partitioning.isPresent() || partitioning.get().getColumns().equals(partitioningColumns), "Partitioning input must match partitioningColumns"); + checkArgument(!partitioning.isPresent() || partitioning.get().getVariableReferences().equals(partitioningColumns), "Partitioning input must match partitioningColumns"); } public PartitioningProperties withNullsAndAnyReplicated(boolean nullsAndAnyReplicated) @@ -323,10 +323,10 @@ public PartitioningProperties withNullsAndAnyReplicated(boolean nullsAndAnyRepli public static PartitioningProperties partitioned(Partitioning partitioning) { - return new PartitioningProperties(partitioning.getColumns(), Optional.of(partitioning), false); + return new PartitioningProperties(partitioning.getVariableReferences(), Optional.of(partitioning), false); } - public static PartitioningProperties partitioned(Set columns) + public static PartitioningProperties partitioned(Set columns) { return new PartitioningProperties(columns, Optional.empty(), false); } @@ -336,7 +336,7 @@ public static PartitioningProperties singlePartition() return partitioned(ImmutableSet.of()); } - public Set getPartitioningColumns() + public Set getPartitioningColumns() { return partitioningColumns; } @@ -370,13 +370,13 @@ public PartitioningProperties mergeWithParent(PartitioningProperties parent) } // Otherwise partition on any common columns if available - Set common = Sets.intersection(partitioningColumns, parent.partitioningColumns); + Set common = Sets.intersection(partitioningColumns, parent.partitioningColumns); return common.isEmpty() ? this : partitioned(common).withNullsAndAnyReplicated(nullsAndAnyReplicated); } - public Optional translate(Function> translator) + public Optional translate(Function> translator) { - Set newPartitioningColumns = partitioningColumns.stream() + Set newPartitioningColumns = partitioningColumns.stream() .map(translator) .filter(Optional::isPresent) .map(Optional::get) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java index 2367b43e62183..389b0126e2562 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java @@ -24,8 +24,9 @@ import com.facebook.presto.spi.GroupingProperty; import com.facebook.presto.spi.LocalProperty; import com.facebook.presto.spi.SortingProperty; -import com.facebook.presto.spi.predicate.NullableValue; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.ExpressionDomainTranslator; @@ -89,7 +90,7 @@ import java.util.stream.Collectors; import static com.facebook.presto.SystemSessionProperties.planWithTableNodePartitioning; -import static com.facebook.presto.spi.predicate.TupleDomain.extractFixedValues; +import static com.facebook.presto.spi.predicate.TupleDomain.extractFixedValuesToConstantExpressions; import static com.facebook.presto.spi.relation.DomainTranslator.BASIC_COLUMN_EXTRACTOR; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.ARBITRARY_DISTRIBUTION; @@ -98,12 +99,15 @@ import static com.facebook.presto.sql.planner.optimizations.ActualProperties.Global.partitionedOn; import static com.facebook.presto.sql.planner.optimizations.ActualProperties.Global.singleStreamPartition; import static com.facebook.presto.sql.planner.optimizations.ActualProperties.Global.streamPartitionedOn; +import static com.facebook.presto.sql.planner.optimizations.AddExchanges.computeIdentityTranslations; +import static com.facebook.presto.sql.planner.optimizations.AddExchanges.toVariableReference; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; @@ -126,15 +130,15 @@ public static ActualProperties deriveProperties(PlanNode node, List - verify(node.getOutputSymbols().containsAll(partitioning.getColumns()), "Node-level partitioning properties contain columns not present in node's output")); + verify(node.getOutputVariables().containsAll(partitioning.getVariableReferences()), "Node-level partitioning properties contain columns not present in node's output")); - verify(node.getOutputSymbols().containsAll(output.getConstants().keySet()), "Node-level constant properties contain columns not present in node's output"); + verify(node.getOutputVariables().containsAll(output.getConstants().keySet()), "Node-level constant properties contain columns not present in node's output"); - Set localPropertyColumns = output.getLocalProperties().stream() + Set localPropertyColumns = output.getLocalProperties().stream() .flatMap(property -> property.getColumns().stream()) .collect(Collectors.toSet()); - verify(node.getOutputSymbols().containsAll(localPropertyColumns), "Node-level local properties contain columns not present in node's output"); + verify(node.getOutputVariables().containsAll(localPropertyColumns), "Node-level local properties contain columns not present in node's output"); return output; } @@ -177,7 +181,7 @@ public ActualProperties visitExplainAnalyze(ExplainAnalyzeNode node, List inputProperties) { return Iterables.getOnlyElement(inputProperties) - .translate(column -> PropertyDerivations.filterIfMissing(node.getOutputSymbols(), column)); + .translate(column -> PropertyDerivations.filterIfMissing(node.getOutputVariables(), column)); } @Override @@ -191,10 +195,10 @@ public ActualProperties visitAssignUniqueId(AssignUniqueId node, List> newLocalProperties = ImmutableList.builder(); + ImmutableList.Builder> newLocalProperties = ImmutableList.builder(); newLocalProperties.addAll(properties.getLocalProperties()); - newLocalProperties.add(new GroupingProperty<>(ImmutableList.of(node.getIdColumn()))); - node.getSource().getOutputSymbols().stream() + newLocalProperties.add(new GroupingProperty<>(ImmutableList.of(node.getIdVariable()))); + node.getSource().getOutputVariables().stream() .forEach(column -> newLocalProperties.add(new ConstantProperty<>(column))); if (properties.getNodePartitioning().isPresent()) { @@ -205,7 +209,7 @@ public ActualProperties visitAssignUniqueId(AssignUniqueId node, List inpu return properties; } - ImmutableList.Builder> localProperties = ImmutableList.builder(); + ImmutableList.Builder> localProperties = ImmutableList.builder(); // If the WindowNode has pre-partitioned inputs, then it will not change the order of those inputs at output, // so we should just propagate those underlying local properties that guarantee the pre-partitioning. // TODO: come up with a more general form of this operation for other streaming operators if (!node.getPrePartitionedInputs().isEmpty()) { - GroupingProperty prePartitionedProperty = new GroupingProperty<>(node.getPrePartitionedInputs()); - for (LocalProperty localProperty : properties.getLocalProperties()) { + GroupingProperty prePartitionedProperty = new GroupingProperty<>(node.getPrePartitionedInputs()); + for (LocalProperty localProperty : properties.getLocalProperties()) { if (!prePartitionedProperty.isSimplifiedBy(localProperty)) { break; } @@ -272,8 +276,8 @@ public ActualProperties visitWindow(WindowNode node, List inpu @Override public ActualProperties visitGroupId(GroupIdNode node, List inputProperties) { - Map inputToOutputMappings = new HashMap<>(); - for (Map.Entry setMapping : node.getGroupingColumns().entrySet()) { + Map inputToOutputMappings = new HashMap<>(); + for (Map.Entry setMapping : node.getGroupingColumns().entrySet()) { if (node.getCommonGroupingColumns().contains(setMapping.getKey())) { // TODO: Add support for translating a property on a single column to multiple columns // when GroupIdNode is copying a single input grouping column into multiple output grouping columns (i.e. aliases), this is basically picking one arbitrarily @@ -283,7 +287,7 @@ public ActualProperties visitGroupId(GroupIdNode node, List in // TODO: Add support for translating a property on a single column to multiple columns // this is deliberately placed after the grouping columns, because preserving properties has a bigger perf impact - for (Symbol argument : node.getAggregationArguments()) { + for (VariableReferenceExpression argument : node.getAggregationArguments()) { inputToOutputMappings.putIfAbsent(argument, argument); } @@ -295,7 +299,7 @@ public ActualProperties visitAggregation(AggregationNode node, List node.getGroupingKeys().contains(symbol) ? Optional.of(symbol) : Optional.empty()); + ActualProperties translated = properties.translate(variable -> node.getGroupingKeys().contains(variable) ? Optional.of(variable) : Optional.empty()); return ActualProperties.builderFrom(translated) .local(LocalProperties.grouped(node.getGroupingKeys())) @@ -313,9 +317,9 @@ public ActualProperties visitTopNRowNumber(TopNRowNumberNode node, List> localProperties = ImmutableList.builder(); + ImmutableList.Builder> localProperties = ImmutableList.builder(); localProperties.add(new GroupingProperty<>(node.getPartitionBy())); - for (Symbol column : node.getOrderingScheme().getOrderBy()) { + for (VariableReferenceExpression column : node.getOrderingScheme().getOrderBy()) { localProperties.add(new SortingProperty<>(column, node.getOrderingScheme().getOrdering(column))); } @@ -329,7 +333,7 @@ public ActualProperties visitTopN(TopNNode node, List inputPro { ActualProperties properties = Iterables.getOnlyElement(inputProperties); - List> localProperties = node.getOrderingScheme().getOrderBy().stream() + List> localProperties = node.getOrderingScheme().getOrderBy().stream() .map(column -> new SortingProperty<>(column, node.getOrderingScheme().getOrdering(column))) .collect(toImmutableList()); @@ -343,7 +347,7 @@ public ActualProperties visitSort(SortNode node, List inputPro { ActualProperties properties = Iterables.getOnlyElement(inputProperties); - List> localProperties = node.getOrderingScheme().getOrderBy().stream() + List> localProperties = node.getOrderingScheme().getOrderBy().stream() .map(column -> new SortingProperty<>(column, node.getOrderingScheme().getOrdering(column))) .collect(toImmutableList()); @@ -364,7 +368,7 @@ public ActualProperties visitDistinctLimit(DistinctLimitNode node, List inputPro { ActualProperties probeProperties = inputProperties.get(0); ActualProperties buildProperties = inputProperties.get(1); + List outputVariableReferences = node.getOutputVariables(); boolean unordered = spillPossible(session, node.getType()); switch (node.getType()) { case INNER: - probeProperties = probeProperties.translate(column -> filterOrRewrite(node.getOutputSymbols(), node.getCriteria(), column)); - buildProperties = buildProperties.translate(column -> filterOrRewrite(node.getOutputSymbols(), node.getCriteria(), column)); + probeProperties = probeProperties.translate(column -> filterOrRewrite(outputVariableReferences, node.getCriteria(), column)); + buildProperties = buildProperties.translate(column -> filterOrRewrite(outputVariableReferences, node.getCriteria(), column)); - Map constants = new HashMap<>(); + Map constants = new HashMap<>(); constants.putAll(probeProperties.getConstants()); constants.putAll(buildProperties.getConstants()); @@ -423,13 +428,13 @@ public ActualProperties visitJoin(JoinNode node, List inputPro .unordered(unordered) .build(); case LEFT: - return ActualProperties.builderFrom(probeProperties.translate(column -> filterIfMissing(node.getOutputSymbols(), column))) + return ActualProperties.builderFrom(probeProperties.translate(column -> filterIfMissing(outputVariableReferences, column))) .unordered(unordered) .build(); case RIGHT: - buildProperties = buildProperties.translate(column -> filterIfMissing(node.getOutputSymbols(), column)); + buildProperties = buildProperties.translate(column -> filterIfMissing(node.getOutputVariables(), column)); - return ActualProperties.builderFrom(buildProperties.translate(column -> filterIfMissing(node.getOutputSymbols(), column))) + return ActualProperties.builderFrom(buildProperties.translate(column -> filterIfMissing(outputVariableReferences, column))) .local(ImmutableList.of()) .unordered(true) .build(); @@ -455,13 +460,14 @@ public ActualProperties visitSpatialJoin(SpatialJoinNode node, List outputs = node.getOutputVariables(); switch (node.getType()) { case INNER: - probeProperties = probeProperties.translate(column -> filterIfMissing(node.getOutputSymbols(), column)); - buildProperties = buildProperties.translate(column -> filterIfMissing(node.getOutputSymbols(), column)); + probeProperties = probeProperties.translate(column -> filterIfMissing(outputs, column)); + buildProperties = buildProperties.translate(column -> filterIfMissing(outputs, column)); - Map constants = new HashMap<>(); + Map constants = new HashMap<>(); constants.putAll(probeProperties.getConstants()); constants.putAll(buildProperties.getConstants()); @@ -469,7 +475,7 @@ public ActualProperties visitSpatialJoin(SpatialJoinNode node, List filterIfMissing(node.getOutputSymbols(), column))) + return ActualProperties.builderFrom(probeProperties.translate(column -> filterIfMissing(outputs, column))) .build(); default: throw new IllegalArgumentException("Unsupported spatial join type: " + node.getType()); @@ -486,7 +492,7 @@ public ActualProperties visitIndexJoin(IndexJoinNode node, Listbuilder() + .constants(ImmutableMap.builder() .putAll(probeProperties.getConstants()) .putAll(indexProperties.getConstants()) .build()) @@ -508,34 +514,29 @@ public ActualProperties visitIndexSource(IndexSourceNode node, List exchangeInputToOutput(ExchangeNode node, int sourceIndex) - { - List inputSymbols = node.getInputs().get(sourceIndex); - Map inputToOutput = new HashMap<>(); - for (int i = 0; i < node.getOutputSymbols().size(); i++) { - inputToOutput.put(inputSymbols.get(i), node.getOutputSymbols().get(i)); - } - return inputToOutput; - } - @Override public ActualProperties visitExchange(ExchangeNode node, List inputProperties) { checkArgument(!node.getScope().isRemote() || inputProperties.stream().noneMatch(ActualProperties::isNullsAndAnyReplicated), "Null-and-any replicated inputs should not be remotely exchanged"); - Set> entries = null; + Set> entries = null; for (int sourceIndex = 0; sourceIndex < node.getSources().size(); sourceIndex++) { - Map inputToOutput = exchangeInputToOutput(node, sourceIndex); - ActualProperties translated = inputProperties.get(sourceIndex).translate(symbol -> Optional.ofNullable(inputToOutput.get(symbol))); + List inputVariables = node.getInputs().get(sourceIndex); + Map inputToOutput = new HashMap<>(); + for (int i = 0; i < node.getOutputVariables().size(); i++) { + inputToOutput.put(inputVariables.get(i), node.getOutputVariables().get(i)); + } + + ActualProperties translated = inputProperties.get(sourceIndex).translate(variable -> Optional.ofNullable(inputToOutput.get(variable))); entries = (entries == null) ? translated.getConstants().entrySet() : Sets.intersection(entries, translated.getConstants().entrySet()); } checkState(entries != null); - Map constants = entries.stream() + Map constants = entries.stream() .collect(toMap(Map.Entry::getKey, Map.Entry::getValue)); - ImmutableList.Builder> localProperties = ImmutableList.builder(); + ImmutableList.Builder> localProperties = ImmutableList.builder(); if (node.getOrderingScheme().isPresent()) { node.getOrderingScheme().get().getOrderBy().stream() .map(column -> new SortingProperty<>(column, node.getOrderingScheme().get().getOrdering(column))) @@ -596,19 +597,20 @@ public ActualProperties visitFilter(FilterNode node, List inpu { ActualProperties properties = Iterables.getOnlyElement(inputProperties); - TupleDomain tupleDomain; + Map constants = new HashMap<>(properties.getConstants()); if (isExpression(node.getPredicate())) { - tupleDomain = ExpressionDomainTranslator.fromPredicate(metadata, session, castToExpression(node.getPredicate()), types).getTupleDomain(); + TupleDomain tupleDomain = ExpressionDomainTranslator.fromPredicate(metadata, session, castToExpression(node.getPredicate()), types).getTupleDomain(); + constants.putAll(extractFixedValuesToConstantExpressions(tupleDomain) + .map(values -> values.entrySet().stream() + .collect(toImmutableMap(entry -> toVariableReference(entry.getKey(), types), Map.Entry::getValue))) + .orElse(ImmutableMap.of())); } else { - tupleDomain = new RowExpressionDomainTranslator(metadata).fromPredicate(session.toConnectorSession(), node.getPredicate(), BASIC_COLUMN_EXTRACTOR) - .getTupleDomain() - .transform(column -> new Symbol(column.getName())); + TupleDomain tupleDomain = new RowExpressionDomainTranslator(metadata).fromPredicate(session.toConnectorSession(), node.getPredicate(), BASIC_COLUMN_EXTRACTOR).getTupleDomain(); + constants.putAll(extractFixedValuesToConstantExpressions(tupleDomain) + .orElse(ImmutableMap.of())); } - Map constants = new HashMap<>(properties.getConstants()); - constants.putAll(extractFixedValues(tupleDomain).orElse(ImmutableMap.of())); - return ActualProperties.builderFrom(properties) .constants(constants) .build(); @@ -619,14 +621,15 @@ public ActualProperties visitProject(ProjectNode node, List in { ActualProperties properties = Iterables.getOnlyElement(inputProperties); - Map identities = computeIdentityTranslations(node.getAssignments().getMap()); + Map identities = computeIdentityTranslations(node.getAssignments(), types); ActualProperties translatedProperties = properties.translate(column -> Optional.ofNullable(identities.get(column))); // Extract additional constants - Map constants = new HashMap<>(); - for (Map.Entry assignment : node.getAssignments().entrySet()) { + Map constants = new HashMap<>(); + for (Map.Entry assignment : node.getAssignments().entrySet()) { Expression expression = assignment.getValue(); + VariableReferenceExpression output = assignment.getKey(); Map, Type> expressionTypes = getExpressionTypes(session, metadata, parser, types, expression, emptyList(), WarningCollector.NOOP); Type type = requireNonNull(expressionTypes.get(NodeRef.of(expression))); @@ -639,14 +642,14 @@ public ActualProperties visitProject(ProjectNode node, List in Object value = optimizer.optimize(NoOpSymbolResolver.INSTANCE); if (value instanceof SymbolReference) { - Symbol symbol = Symbol.from((SymbolReference) value); - NullableValue existingConstantValue = constants.get(symbol); + VariableReferenceExpression variable = toVariableReference(Symbol.from((SymbolReference) value), types); + ConstantExpression existingConstantValue = constants.get(variable); if (existingConstantValue != null) { - constants.put(assignment.getKey(), new NullableValue(type, value)); + constants.put(output, new ConstantExpression(value, type)); } } else if (!(value instanceof Expression)) { - constants.put(assignment.getKey(), new NullableValue(type, value)); + constants.put(output, new ConstantExpression(value, type)); } } constants.putAll(translatedProperties.getConstants()); @@ -680,7 +683,7 @@ public ActualProperties visitSample(SampleNode node, List inpu @Override public ActualProperties visitUnnest(UnnestNode node, List inputProperties) { - Set passThroughInputs = ImmutableSet.copyOf(node.getReplicateSymbols()); + Set passThroughInputs = ImmutableSet.copyOf(node.getReplicateVariables()); return Iterables.getOnlyElement(inputProperties).translate(column -> { if (passThroughInputs.contains(column)) { @@ -702,18 +705,18 @@ public ActualProperties visitValues(ValuesNode node, List cont public ActualProperties visitTableScan(TableScanNode node, List inputProperties) { TableLayout layout = metadata.getLayout(session, node.getTable()); - Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); + Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); ActualProperties.Builder properties = ActualProperties.builder(); // Globally constant assignments - Map globalConstants = new HashMap<>(); - extractFixedValues(node.getCurrentConstraint()).orElse(ImmutableMap.of()) + Map globalConstants = new HashMap<>(); + extractFixedValuesToConstantExpressions(node.getCurrentConstraint()).orElse(ImmutableMap.of()) .entrySet().stream() .filter(entry -> !entry.getValue().isNull()) .forEach(entry -> globalConstants.put(entry.getKey(), entry.getValue())); - Map symbolConstants = globalConstants.entrySet().stream() + Map symbolConstants = globalConstants.entrySet().stream() .filter(entry -> assignments.containsKey(entry.getKey())) .collect(toMap(entry -> assignments.get(entry.getKey()), Map.Entry::getValue)); properties.constants(symbolConstants); @@ -731,15 +734,15 @@ public ActualProperties visitTableScan(TableScanNode node, List assignments, Map constants) + private Global deriveGlobalProperties(TableLayout layout, Map assignments, Map constants) { - Optional> streamPartitioning = layout.getStreamPartitioningColumns() + Optional> streamPartitioning = layout.getStreamPartitioningColumns() .flatMap(columns -> translateToNonConstantSymbols(columns, assignments, constants)); if (planWithTableNodePartitioning(session) && layout.getTablePartitioning().isPresent()) { TablePartitioning tablePartitioning = layout.getTablePartitioning().get(); if (assignments.keySet().containsAll(tablePartitioning.getPartitioningColumns())) { - List arguments = tablePartitioning.getPartitioningColumns().stream() + List arguments = tablePartitioning.getPartitioningColumns().stream() .map(assignments::get) .collect(toImmutableList()); @@ -753,19 +756,19 @@ private Global deriveGlobalProperties(TableLayout layout, Map> translateToNonConstantSymbols( + private static Optional> translateToNonConstantSymbols( Set columnHandles, - Map assignments, - Map globalConstants) + Map assignments, + Map globalConstants) { // Strip off the constants from the partitioning columns (since those are not required for translation) Set constantsStrippedColumns = columnHandles.stream() .filter(column -> !globalConstants.containsKey(column)) .collect(toImmutableSet()); - ImmutableSet.Builder builder = ImmutableSet.builder(); + ImmutableSet.Builder builder = ImmutableSet.builder(); for (ColumnHandle column : constantsStrippedColumns) { - Symbol translated = assignments.get(column); + VariableReferenceExpression translated = assignments.get(column); if (translated == null) { return Optional.empty(); } @@ -774,17 +777,6 @@ private static Optional> translateToNonConstantSymbols( return Optional.of(ImmutableList.copyOf(builder.build())); } - - private static Map computeIdentityTranslations(Map assignments) - { - Map inputToOutput = new HashMap<>(); - for (Map.Entry assignment : assignments.entrySet()) { - if (assignment.getValue() instanceof SymbolReference) { - inputToOutput.put(Symbol.from(assignment.getValue()), assignment.getKey()); - } - } - return inputToOutput; - } } static boolean spillPossible(Session session, JoinNode.Type joinType) @@ -805,7 +797,7 @@ static boolean spillPossible(Session session, JoinNode.Type joinType) } } - public static Optional filterIfMissing(Collection columns, Symbol column) + public static Optional filterIfMissing(Collection columns, VariableReferenceExpression column) { if (columns.contains(column)) { return Optional.of(column); @@ -817,7 +809,7 @@ public static Optional filterIfMissing(Collection columns, Symbo // Used to filter columns that are not exposed by join node // Or, if they are part of the equalities, to translate them // to the other symbol if that's exposed, instead. - public static Optional filterOrRewrite(Collection columns, Collection equalities, Symbol column) + public static Optional filterOrRewrite(Collection columns, List equalities, VariableReferenceExpression column) { // symbol is exposed directly, so no translation needed if (columns.contains(column)) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index e0ae7cc83e82c..99bf04876e44b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -18,8 +18,8 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.PartitioningScheme; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.TypeProvider; @@ -79,18 +79,19 @@ import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.function.Function; import java.util.stream.Collectors; -import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.extractUnique; +import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.extractUniqueVariables; import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isScalar; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.concat; import static com.google.common.collect.Sets.intersection; import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; /** * Removes all computation that does is not referenced transitively from the root of the plan @@ -115,54 +116,61 @@ public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, Sym requireNonNull(symbolAllocator, "symbolAllocator is null"); requireNonNull(idAllocator, "idAllocator is null"); - return SimplePlanRewriter.rewriteWith(new Rewriter(), plan, ImmutableSet.of()); + return SimplePlanRewriter.rewriteWith(new Rewriter(symbolAllocator), plan, ImmutableSet.of()); } private static class Rewriter - extends SimplePlanRewriter> + extends SimplePlanRewriter> { + private final SymbolAllocator symbolAllocator; + + public Rewriter(SymbolAllocator symbolAllocator) + { + this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); + } + @Override - public PlanNode visitExplainAnalyze(ExplainAnalyzeNode node, RewriteContext> context) + public PlanNode visitExplainAnalyze(ExplainAnalyzeNode node, RewriteContext> context) { - return context.defaultRewrite(node, ImmutableSet.copyOf(node.getSource().getOutputSymbols())); + return context.defaultRewrite(node, ImmutableSet.copyOf(node.getSource().getOutputVariables())); } @Override - public PlanNode visitExchange(ExchangeNode node, RewriteContext> context) + public PlanNode visitExchange(ExchangeNode node, RewriteContext> context) { - Set expectedOutputSymbols = Sets.newHashSet(context.get()); - node.getPartitioningScheme().getHashColumn().ifPresent(expectedOutputSymbols::add); - node.getPartitioningScheme().getPartitioning().getColumns().stream() - .forEach(expectedOutputSymbols::add); - node.getOrderingScheme().ifPresent(orderingScheme -> expectedOutputSymbols.addAll(orderingScheme.getOrderBy())); + Set expectedOutputVariables = Sets.newHashSet(context.get()); + node.getPartitioningScheme().getHashColumn().ifPresent(expectedOutputVariables::add); + node.getPartitioningScheme().getPartitioning().getVariableReferences() + .forEach(expectedOutputVariables::add); + node.getOrderingScheme().ifPresent(orderingScheme -> expectedOutputVariables.addAll(orderingScheme.getOrderBy())); - List> inputsBySource = new ArrayList<>(node.getInputs().size()); + List> inputsBySource = new ArrayList<>(node.getInputs().size()); for (int i = 0; i < node.getInputs().size(); i++) { inputsBySource.add(new ArrayList<>()); } - List newOutputSymbols = new ArrayList<>(node.getOutputSymbols().size()); - for (int i = 0; i < node.getOutputSymbols().size(); i++) { - Symbol outputSymbol = node.getOutputSymbols().get(i); - if (expectedOutputSymbols.contains(outputSymbol)) { - newOutputSymbols.add(outputSymbol); + List newOutputVariables = new ArrayList<>(node.getOutputVariables().size()); + for (int i = 0; i < node.getOutputVariables().size(); i++) { + VariableReferenceExpression outputVariable = node.getOutputVariables().get(i); + if (expectedOutputVariables.contains(outputVariable)) { + newOutputVariables.add(outputVariable); for (int source = 0; source < node.getInputs().size(); source++) { inputsBySource.get(source).add(node.getInputs().get(source).get(i)); } } } - // newOutputSymbols contains all partition, sort and hash symbols so simply swap the output layout + // newOutputVariables contains all partition, sort and hash variables so simply swap the output layout PartitioningScheme partitioningScheme = new PartitioningScheme( node.getPartitioningScheme().getPartitioning(), - newOutputSymbols, + newOutputVariables, node.getPartitioningScheme().getHashColumn(), node.getPartitioningScheme().isReplicateNullsAndAny(), node.getPartitioningScheme().getBucketToPartition()); ImmutableList.Builder rewrittenSources = ImmutableList.builder(); for (int i = 0; i < node.getSources().size(); i++) { - ImmutableSet.Builder expectedInputs = ImmutableSet.builder() + ImmutableSet.Builder expectedInputs = ImmutableSet.builder() .addAll(inputsBySource.get(i)); rewrittenSources.add(context.rewrite( @@ -181,70 +189,70 @@ public PlanNode visitExchange(ExchangeNode node, RewriteContext> con } @Override - public PlanNode visitJoin(JoinNode node, RewriteContext> context) + public PlanNode visitJoin(JoinNode node, RewriteContext> context) { - Set expectedFilterInputs = new HashSet<>(); + Set expectedFilterInputs = new HashSet<>(); if (node.getFilter().isPresent()) { - expectedFilterInputs = ImmutableSet.builder() - .addAll(SymbolsExtractor.extractUnique(castToExpression(node.getFilter().get()))) + expectedFilterInputs = ImmutableSet.builder() + .addAll(SymbolsExtractor.extractUniqueVariable(castToExpression(node.getFilter().get()), symbolAllocator.getTypes())) .addAll(context.get()) .build(); } - ImmutableSet.Builder leftInputsBuilder = ImmutableSet.builder(); + ImmutableSet.Builder leftInputsBuilder = ImmutableSet.builder(); leftInputsBuilder.addAll(context.get()).addAll(Iterables.transform(node.getCriteria(), JoinNode.EquiJoinClause::getLeft)); - if (node.getLeftHashSymbol().isPresent()) { - leftInputsBuilder.add(node.getLeftHashSymbol().get()); + if (node.getLeftHashVariable().isPresent()) { + leftInputsBuilder.add(node.getLeftHashVariable().get()); } leftInputsBuilder.addAll(expectedFilterInputs); - Set leftInputs = leftInputsBuilder.build(); + Set leftInputs = leftInputsBuilder.build(); - ImmutableSet.Builder rightInputsBuilder = ImmutableSet.builder(); + ImmutableSet.Builder rightInputsBuilder = ImmutableSet.builder(); rightInputsBuilder.addAll(context.get()).addAll(Iterables.transform(node.getCriteria(), JoinNode.EquiJoinClause::getRight)); - if (node.getRightHashSymbol().isPresent()) { - rightInputsBuilder.add(node.getRightHashSymbol().get()); + if (node.getRightHashVariable().isPresent()) { + rightInputsBuilder.add(node.getRightHashVariable().get()); } rightInputsBuilder.addAll(expectedFilterInputs); - Set rightInputs = rightInputsBuilder.build(); + Set rightInputs = rightInputsBuilder.build(); PlanNode left = context.rewrite(node.getLeft(), leftInputs); PlanNode right = context.rewrite(node.getRight(), rightInputs); - List outputSymbols; + List outputVariables; if (node.isCrossJoin()) { // do not prune nested joins output since it is not supported // TODO: remove this "if" branch when output symbols selection is supported by nested loop join - outputSymbols = ImmutableList.builder() - .addAll(left.getOutputSymbols()) - .addAll(right.getOutputSymbols()) + outputVariables = ImmutableList.builder() + .addAll(left.getOutputVariables()) + .addAll(right.getOutputVariables()) .build(); } else { - outputSymbols = node.getOutputSymbols().stream() - .filter(context.get()::contains) + outputVariables = node.getOutputVariables().stream() + .filter(variable -> context.get().contains(variable)) .distinct() .collect(toImmutableList()); } - return new JoinNode(node.getId(), node.getType(), left, right, node.getCriteria(), outputSymbols, node.getFilter(), node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType()); + return new JoinNode(node.getId(), node.getType(), left, right, node.getCriteria(), outputVariables, node.getFilter(), node.getLeftHashVariable(), node.getRightHashVariable(), node.getDistributionType()); } @Override - public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext> context) + public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext> context) { - ImmutableSet.Builder sourceInputsBuilder = ImmutableSet.builder(); - sourceInputsBuilder.addAll(context.get()).add(node.getSourceJoinSymbol()); - if (node.getSourceHashSymbol().isPresent()) { - sourceInputsBuilder.add(node.getSourceHashSymbol().get()); + ImmutableSet.Builder sourceInputsBuilder = ImmutableSet.builder(); + sourceInputsBuilder.addAll(context.get()).add(node.getSourceJoinVariable()); + if (node.getSourceHashVariable().isPresent()) { + sourceInputsBuilder.add(node.getSourceHashVariable().get()); } - Set sourceInputs = sourceInputsBuilder.build(); + Set sourceInputs = sourceInputsBuilder.build(); - ImmutableSet.Builder filteringSourceInputBuilder = ImmutableSet.builder(); - filteringSourceInputBuilder.add(node.getFilteringSourceJoinSymbol()); - if (node.getFilteringSourceHashSymbol().isPresent()) { - filteringSourceInputBuilder.add(node.getFilteringSourceHashSymbol().get()); + ImmutableSet.Builder filteringSourceInputBuilder = ImmutableSet.builder(); + filteringSourceInputBuilder.add(node.getFilteringSourceJoinVariable()); + if (node.getFilteringSourceHashVariable().isPresent()) { + filteringSourceInputBuilder.add(node.getFilteringSourceHashVariable().get()); } - Set filteringSourceInputs = filteringSourceInputBuilder.build(); + Set filteringSourceInputs = filteringSourceInputBuilder.build(); PlanNode source = context.rewrite(node.getSource(), sourceInputs); PlanNode filteringSource = context.rewrite(node.getFilteringSource(), filteringSourceInputs); @@ -252,106 +260,106 @@ public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext> con return new SemiJoinNode(node.getId(), source, filteringSource, - node.getSourceJoinSymbol(), - node.getFilteringSourceJoinSymbol(), + node.getSourceJoinVariable(), + node.getFilteringSourceJoinVariable(), node.getSemiJoinOutput(), - node.getSourceHashSymbol(), - node.getFilteringSourceHashSymbol(), + node.getSourceHashVariable(), + node.getFilteringSourceHashVariable(), node.getDistributionType()); } @Override - public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext> context) + public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext> context) { - Set filterSymbols; + Set filterSymbols; if (isExpression(node.getFilter())) { - filterSymbols = SymbolsExtractor.extractUnique(castToExpression(node.getFilter())); + filterSymbols = SymbolsExtractor.extractUniqueVariable(castToExpression(node.getFilter()), symbolAllocator.getTypes()); } else { - filterSymbols = SymbolsExtractor.extractUnique(node.getFilter()); + filterSymbols = SymbolsExtractor.extractUniqueVariable(node.getFilter()); } - Set requiredInputs = ImmutableSet.builder() + Set requiredInputs = ImmutableSet.builder() .addAll(filterSymbols) .addAll(context.get()) .build(); - ImmutableSet.Builder leftInputs = ImmutableSet.builder(); - node.getLeftPartitionSymbol().map(leftInputs::add); + ImmutableSet.Builder leftInputs = ImmutableSet.builder(); + node.getLeftPartitionVariable().map(leftInputs::add); - ImmutableSet.Builder rightInputs = ImmutableSet.builder(); - node.getRightPartitionSymbol().map(rightInputs::add); + ImmutableSet.Builder rightInputs = ImmutableSet.builder(); + node.getRightPartitionVariable().map(rightInputs::add); PlanNode left = context.rewrite(node.getLeft(), leftInputs.addAll(requiredInputs).build()); PlanNode right = context.rewrite(node.getRight(), rightInputs.addAll(requiredInputs).build()); - List outputSymbols = node.getOutputSymbols().stream() + List outputVariables = node.getOutputVariables().stream() .filter(context.get()::contains) .distinct() .collect(toImmutableList()); - return new SpatialJoinNode(node.getId(), node.getType(), left, right, outputSymbols, node.getFilter(), node.getLeftPartitionSymbol(), node.getRightPartitionSymbol(), node.getKdbTree()); + return new SpatialJoinNode(node.getId(), node.getType(), left, right, outputVariables, node.getFilter(), node.getLeftPartitionVariable(), node.getRightPartitionVariable(), node.getKdbTree()); } @Override - public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext> context) + public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext> context) { - ImmutableSet.Builder probeInputsBuilder = ImmutableSet.builder(); + ImmutableSet.Builder probeInputsBuilder = ImmutableSet.builder(); probeInputsBuilder.addAll(context.get()) .addAll(Iterables.transform(node.getCriteria(), IndexJoinNode.EquiJoinClause::getProbe)); - if (node.getProbeHashSymbol().isPresent()) { - probeInputsBuilder.add(node.getProbeHashSymbol().get()); + if (node.getProbeHashVariable().isPresent()) { + probeInputsBuilder.add(node.getProbeHashVariable().get()); } - Set probeInputs = probeInputsBuilder.build(); + Set probeInputs = probeInputsBuilder.build(); - ImmutableSet.Builder indexInputBuilder = ImmutableSet.builder(); + ImmutableSet.Builder indexInputBuilder = ImmutableSet.builder(); indexInputBuilder.addAll(context.get()) .addAll(Iterables.transform(node.getCriteria(), IndexJoinNode.EquiJoinClause::getIndex)); - if (node.getIndexHashSymbol().isPresent()) { - indexInputBuilder.add(node.getIndexHashSymbol().get()); + if (node.getIndexHashVariable().isPresent()) { + indexInputBuilder.add(node.getIndexHashVariable().get()); } - Set indexInputs = indexInputBuilder.build(); + Set indexInputs = indexInputBuilder.build(); PlanNode probeSource = context.rewrite(node.getProbeSource(), probeInputs); PlanNode indexSource = context.rewrite(node.getIndexSource(), indexInputs); - return new IndexJoinNode(node.getId(), node.getType(), probeSource, indexSource, node.getCriteria(), node.getProbeHashSymbol(), node.getIndexHashSymbol()); + return new IndexJoinNode(node.getId(), node.getType(), probeSource, indexSource, node.getCriteria(), node.getProbeHashVariable(), node.getIndexHashVariable()); } @Override - public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext> context) + public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext> context) { - List newOutputSymbols = node.getOutputSymbols().stream() + List newOutputVariables = node.getOutputVariables().stream() .filter(context.get()::contains) .collect(toImmutableList()); - Set newLookupSymbols = node.getLookupSymbols().stream() + Set newLookupVariables = node.getLookupVariables().stream() .filter(context.get()::contains) .collect(toImmutableSet()); - Map newAssignments = newOutputSymbols.stream() - .collect(Collectors.toMap(Function.identity(), node.getAssignments()::get)); + Map newAssignments = newOutputVariables.stream() + .collect(toImmutableMap(identity(), node.getAssignments()::get)); - return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), newLookupSymbols, newOutputSymbols, newAssignments, node.getCurrentConstraint()); + return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), newLookupVariables, newOutputVariables, newAssignments, node.getCurrentConstraint()); } @Override - public PlanNode visitAggregation(AggregationNode node, RewriteContext> context) + public PlanNode visitAggregation(AggregationNode node, RewriteContext> context) { - ImmutableSet.Builder expectedInputs = ImmutableSet.builder() + ImmutableSet.Builder expectedInputs = ImmutableSet.builder() .addAll(node.getGroupingKeys()); - if (node.getHashSymbol().isPresent()) { - expectedInputs.add(node.getHashSymbol().get()); + if (node.getHashVariable().isPresent()) { + expectedInputs.add(node.getHashVariable().get()); } - ImmutableMap.Builder aggregations = ImmutableMap.builder(); - for (Map.Entry entry : node.getAggregations().entrySet()) { - Symbol symbol = entry.getKey(); + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + for (Map.Entry entry : node.getAggregations().entrySet()) { + VariableReferenceExpression variable = entry.getKey(); - if (context.get().contains(symbol)) { + if (context.get().contains(variable)) { Aggregation aggregation = entry.getValue(); - expectedInputs.addAll(extractUnique(aggregation)); + expectedInputs.addAll(extractUniqueVariables(aggregation, symbolAllocator.getTypes())); aggregation.getMask().ifPresent(expectedInputs::add); - aggregations.put(symbol, aggregation); + aggregations.put(variable, aggregation); } } @@ -363,14 +371,14 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext> context) + public PlanNode visitWindow(WindowNode node, RewriteContext> context) { - ImmutableSet.Builder expectedInputs = ImmutableSet.builder() + ImmutableSet.Builder expectedInputs = ImmutableSet.builder() .addAll(context.get()) .addAll(node.getPartitionBy()); @@ -387,24 +395,23 @@ public PlanNode visitWindow(WindowNode node, RewriteContext> context } } - if (node.getHashSymbol().isPresent()) { - expectedInputs.add(node.getHashSymbol().get()); + if (node.getHashVariable().isPresent()) { + expectedInputs.add(node.getHashVariable().get()); } - ImmutableMap.Builder functionsBuilder = ImmutableMap.builder(); - for (Map.Entry entry : node.getWindowFunctions().entrySet()) { - Symbol symbol = entry.getKey(); + ImmutableMap.Builder functionsBuilder = ImmutableMap.builder(); + for (Map.Entry entry : node.getWindowFunctions().entrySet()) { + VariableReferenceExpression variable = entry.getKey(); WindowNode.Function function = entry.getValue(); - - if (context.get().contains(symbol)) { - expectedInputs.addAll(WindowNodeUtil.extractWindowFunctionUnique(function)); - functionsBuilder.put(symbol, entry.getValue()); + if (context.get().contains(variable)) { + expectedInputs.addAll(WindowNodeUtil.extractWindowFunctionUniqueVariables(function, symbolAllocator.getTypes())); + functionsBuilder.put(variable, entry.getValue()); } } PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); - Map functions = functionsBuilder.build(); + Map functions = functionsBuilder.build(); if (functions.size() == 0) { return source; @@ -415,20 +422,20 @@ public PlanNode visitWindow(WindowNode node, RewriteContext> context source, node.getSpecification(), functions, - node.getHashSymbol(), + node.getHashVariable(), node.getPrePartitionedInputs(), node.getPreSortedOrderPrefix()); } @Override - public PlanNode visitTableScan(TableScanNode node, RewriteContext> context) + public PlanNode visitTableScan(TableScanNode node, RewriteContext> context) { - List newOutputs = node.getOutputSymbols().stream() + List newOutputs = node.getOutputVariables().stream() .filter(context.get()::contains) .collect(toImmutableList()); - Map newAssignments = newOutputs.stream() - .collect(Collectors.toMap(Function.identity(), node.getAssignments()::get)); + Map newAssignments = newOutputs.stream() + .collect(Collectors.toMap(identity(), node.getAssignments()::get)); return new TableScanNode( node.getId(), @@ -440,10 +447,10 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext> c } @Override - public PlanNode visitFilter(FilterNode node, RewriteContext> context) + public PlanNode visitFilter(FilterNode node, RewriteContext> context) { - Set expectedInputs = ImmutableSet.builder() - .addAll(SymbolsExtractor.extractUnique(castToExpression(node.getPredicate()))) + Set expectedInputs = ImmutableSet.builder() + .addAll(SymbolsExtractor.extractUniqueVariable(castToExpression(node.getPredicate()), symbolAllocator.getTypes())) .addAll(context.get()) .build(); @@ -453,22 +460,22 @@ public PlanNode visitFilter(FilterNode node, RewriteContext> context } @Override - public PlanNode visitGroupId(GroupIdNode node, RewriteContext> context) + public PlanNode visitGroupId(GroupIdNode node, RewriteContext> context) { - ImmutableSet.Builder expectedInputs = ImmutableSet.builder(); + ImmutableSet.Builder expectedInputs = ImmutableSet.builder(); - List newAggregationArguments = node.getAggregationArguments().stream() + List newAggregationArguments = node.getAggregationArguments().stream() .filter(context.get()::contains) .collect(Collectors.toList()); expectedInputs.addAll(newAggregationArguments); - ImmutableList.Builder> newGroupingSets = ImmutableList.builder(); - Map newGroupingMapping = new HashMap<>(); + ImmutableList.Builder> newGroupingSets = ImmutableList.builder(); + Map newGroupingMapping = new HashMap<>(); - for (List groupingSet : node.getGroupingSets()) { - ImmutableList.Builder newGroupingSet = ImmutableList.builder(); + for (List groupingSet : node.getGroupingSets()) { + ImmutableList.Builder newGroupingSet = ImmutableList.builder(); - for (Symbol output : groupingSet) { + for (VariableReferenceExpression output : groupingSet) { if (context.get().contains(output)) { newGroupingSet.add(output); newGroupingMapping.putIfAbsent(output, node.getGroupingColumns().get(output)); @@ -479,60 +486,60 @@ public PlanNode visitGroupId(GroupIdNode node, RewriteContext> conte } PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); - return new GroupIdNode(node.getId(), source, newGroupingSets.build(), newGroupingMapping, newAggregationArguments, node.getGroupIdSymbol()); + return new GroupIdNode(node.getId(), source, newGroupingSets.build(), newGroupingMapping, newAggregationArguments, node.getGroupIdVariable()); } @Override - public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext> context) + public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext> context) { - if (!context.get().contains(node.getMarkerSymbol())) { + if (!context.get().contains(node.getMarkerVariable())) { return context.rewrite(node.getSource(), context.get()); } - ImmutableSet.Builder expectedInputs = ImmutableSet.builder() - .addAll(node.getDistinctSymbols()) + ImmutableSet.Builder expectedInputs = ImmutableSet.builder() + .addAll(node.getDistinctVariables()) .addAll(context.get().stream() - .filter(symbol -> !symbol.equals(node.getMarkerSymbol())) + .filter(variable -> !variable.equals(node.getMarkerVariable())) .collect(toImmutableList())); - if (node.getHashSymbol().isPresent()) { - expectedInputs.add(node.getHashSymbol().get()); + if (node.getHashVariable().isPresent()) { + expectedInputs.add(node.getHashVariable().get()); } PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); - return new MarkDistinctNode(node.getId(), source, node.getMarkerSymbol(), node.getDistinctSymbols(), node.getHashSymbol()); + return new MarkDistinctNode(node.getId(), source, node.getMarkerVariable(), node.getDistinctVariables(), node.getHashVariable()); } @Override - public PlanNode visitUnnest(UnnestNode node, RewriteContext> context) + public PlanNode visitUnnest(UnnestNode node, RewriteContext> context) { - List replicateSymbols = node.getReplicateSymbols().stream() + List replicateVariables = node.getReplicateVariables().stream() .filter(context.get()::contains) .collect(toImmutableList()); - Optional ordinalitySymbol = node.getOrdinalitySymbol(); - if (ordinalitySymbol.isPresent() && !context.get().contains(ordinalitySymbol.get())) { - ordinalitySymbol = Optional.empty(); + Optional ordinalityVariable = node.getOrdinalityVariable(); + if (ordinalityVariable.isPresent() && !context.get().contains(ordinalityVariable.get())) { + ordinalityVariable = Optional.empty(); } - Map> unnestSymbols = node.getUnnestSymbols(); - ImmutableSet.Builder expectedInputs = ImmutableSet.builder() - .addAll(replicateSymbols) - .addAll(unnestSymbols.keySet()); + Map> unnestVariables = node.getUnnestVariables(); + ImmutableSet.Builder expectedInputs = ImmutableSet.builder() + .addAll(replicateVariables) + .addAll(unnestVariables.keySet()); PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); - return new UnnestNode(node.getId(), source, replicateSymbols, unnestSymbols, ordinalitySymbol); + return new UnnestNode(node.getId(), source, replicateVariables, unnestVariables, ordinalityVariable); } @Override - public PlanNode visitProject(ProjectNode node, RewriteContext> context) + public PlanNode visitProject(ProjectNode node, RewriteContext> context) { - ImmutableSet.Builder expectedInputs = ImmutableSet.builder(); + ImmutableSet.Builder expectedInputs = ImmutableSet.builder(); Assignments.Builder builder = Assignments.builder(); - node.getAssignments().forEach((symbol, expression) -> { - if (context.get().contains(symbol)) { - expectedInputs.addAll(SymbolsExtractor.extractUnique(expression)); - builder.put(symbol, expression); + node.getAssignments().forEach((variable, expression) -> { + if (context.get().contains(variable)) { + expectedInputs.addAll(SymbolsExtractor.extractUniqueVariable(expression, symbolAllocator.getTypes())); + builder.put(variable, expression); } }); @@ -542,40 +549,40 @@ public PlanNode visitProject(ProjectNode node, RewriteContext> conte } @Override - public PlanNode visitOutput(OutputNode node, RewriteContext> context) + public PlanNode visitOutput(OutputNode node, RewriteContext> context) { - Set expectedInputs = ImmutableSet.copyOf(node.getOutputSymbols()); + Set expectedInputs = ImmutableSet.copyOf(node.getOutputVariables()); PlanNode source = context.rewrite(node.getSource(), expectedInputs); - return new OutputNode(node.getId(), source, node.getColumnNames(), node.getOutputSymbols()); + return new OutputNode(node.getId(), source, node.getColumnNames(), node.getOutputVariables()); } @Override - public PlanNode visitLimit(LimitNode node, RewriteContext> context) + public PlanNode visitLimit(LimitNode node, RewriteContext> context) { - ImmutableSet.Builder expectedInputs = ImmutableSet.builder() + ImmutableSet.Builder expectedInputs = ImmutableSet.builder() .addAll(context.get()); PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); return new LimitNode(node.getId(), source, node.getCount(), node.isPartial()); } @Override - public PlanNode visitDistinctLimit(DistinctLimitNode node, RewriteContext> context) + public PlanNode visitDistinctLimit(DistinctLimitNode node, RewriteContext> context) { - Set expectedInputs; - if (node.getHashSymbol().isPresent()) { - expectedInputs = ImmutableSet.copyOf(concat(node.getDistinctSymbols(), ImmutableList.of(node.getHashSymbol().get()))); + Set expectedInputs; + if (node.getHashVariable().isPresent()) { + expectedInputs = ImmutableSet.copyOf(concat(node.getDistinctVariables(), ImmutableList.of(node.getHashVariable().get()))); } else { - expectedInputs = ImmutableSet.copyOf(node.getDistinctSymbols()); + expectedInputs = ImmutableSet.copyOf(node.getDistinctVariables()); } PlanNode source = context.rewrite(node.getSource(), expectedInputs); - return new DistinctLimitNode(node.getId(), source, node.getLimit(), node.isPartial(), node.getDistinctSymbols(), node.getHashSymbol()); + return new DistinctLimitNode(node.getId(), source, node.getLimit(), node.isPartial(), node.getDistinctVariables(), node.getHashVariable()); } @Override - public PlanNode visitTopN(TopNNode node, RewriteContext> context) + public PlanNode visitTopN(TopNNode node, RewriteContext> context) { - ImmutableSet.Builder expectedInputs = ImmutableSet.builder() + ImmutableSet.Builder expectedInputs = ImmutableSet.builder() .addAll(context.get()) .addAll(node.getOrderingScheme().getOrderBy()); @@ -585,47 +592,47 @@ public PlanNode visitTopN(TopNNode node, RewriteContext> context) } @Override - public PlanNode visitRowNumber(RowNumberNode node, RewriteContext> context) + public PlanNode visitRowNumber(RowNumberNode node, RewriteContext> context) { - ImmutableSet.Builder inputsBuilder = ImmutableSet.builder(); - ImmutableSet.Builder expectedInputs = inputsBuilder + ImmutableSet.Builder inputsBuilder = ImmutableSet.builder(); + ImmutableSet.Builder expectedInputs = inputsBuilder .addAll(context.get()) .addAll(node.getPartitionBy()); - if (node.getHashSymbol().isPresent()) { - inputsBuilder.add(node.getHashSymbol().get()); + if (node.getHashVariable().isPresent()) { + inputsBuilder.add(node.getHashVariable().get()); } PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); - return new RowNumberNode(node.getId(), source, node.getPartitionBy(), node.getRowNumberSymbol(), node.getMaxRowCountPerPartition(), node.getHashSymbol()); + return new RowNumberNode(node.getId(), source, node.getPartitionBy(), node.getRowNumberVariable(), node.getMaxRowCountPerPartition(), node.getHashVariable()); } @Override - public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext> context) + public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext> context) { - ImmutableSet.Builder expectedInputs = ImmutableSet.builder() + ImmutableSet.Builder expectedInputs = ImmutableSet.builder() .addAll(context.get()) .addAll(node.getPartitionBy()) .addAll(node.getOrderingScheme().getOrderBy()); - if (node.getHashSymbol().isPresent()) { - expectedInputs.add(node.getHashSymbol().get()); + if (node.getHashVariable().isPresent()) { + expectedInputs.add(node.getHashVariable().get()); } PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); return new TopNRowNumberNode(node.getId(), source, node.getSpecification(), - node.getRowNumberSymbol(), + node.getRowNumberVariable(), node.getMaxRowCountPerPartition(), node.isPartial(), - node.getHashSymbol()); + node.getHashVariable()); } @Override - public PlanNode visitSort(SortNode node, RewriteContext> context) + public PlanNode visitSort(SortNode node, RewriteContext> context) { - Set expectedInputs = ImmutableSet.copyOf(concat(context.get(), node.getOrderingScheme().getOrderBy())); + Set expectedInputs = ImmutableSet.copyOf(concat(context.get(), node.getOrderingScheme().getOrderBy())); PlanNode source = context.rewrite(node.getSource(), expectedInputs); @@ -633,28 +640,28 @@ public PlanNode visitSort(SortNode node, RewriteContext> context) } @Override - public PlanNode visitTableWriter(TableWriterNode node, RewriteContext> context) + public PlanNode visitTableWriter(TableWriterNode node, RewriteContext> context) { - ImmutableSet.Builder expectedInputs = ImmutableSet.builder() + ImmutableSet.Builder expectedInputs = ImmutableSet.builder() .addAll(node.getColumns()); if (node.getPartitioningScheme().isPresent()) { PartitioningScheme partitioningScheme = node.getPartitioningScheme().get(); - partitioningScheme.getPartitioning().getColumns().forEach(expectedInputs::add); + partitioningScheme.getPartitioning().getVariableReferences().forEach(expectedInputs::add); partitioningScheme.getHashColumn().ifPresent(expectedInputs::add); } if (node.getStatisticsAggregation().isPresent()) { StatisticAggregations aggregations = node.getStatisticsAggregation().get(); - expectedInputs.addAll(aggregations.getGroupingSymbols()); - aggregations.getAggregations().values().forEach(aggregation -> expectedInputs.addAll(extractUnique(aggregation))); + expectedInputs.addAll(aggregations.getGroupingVariables()); + aggregations.getAggregations().values().forEach(aggregation -> expectedInputs.addAll(extractUniqueVariables(aggregation, symbolAllocator.getTypes()))); } PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); return new TableWriterNode( node.getId(), source, node.getTarget(), - node.getRowCountSymbol(), - node.getFragmentSymbol(), - node.getTableCommitContextSymbol(), + node.getRowCountVariable(), + node.getFragmentVariable(), + node.getTableCommitContextVariable(), node.getColumns(), node.getColumnNames(), node.getPartitioningScheme(), @@ -663,82 +670,84 @@ public PlanNode visitTableWriter(TableWriterNode node, RewriteContext> context) + public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteContext> context) { - PlanNode source = context.rewrite(node.getSource(), ImmutableSet.copyOf(node.getSource().getOutputSymbols())); + PlanNode source = context.rewrite(node.getSource(), ImmutableSet.copyOf(node.getSource().getOutputVariables())); return new StatisticsWriterNode( node.getId(), source, node.getTarget(), - node.getRowCountSymbol(), + node.getRowCountVariable(), node.isRowCountEnabled(), node.getDescriptor()); } @Override - public PlanNode visitTableFinish(TableFinishNode node, RewriteContext> context) + public PlanNode visitTableFinish(TableFinishNode node, RewriteContext> context) { - PlanNode source = context.rewrite(node.getSource(), ImmutableSet.copyOf(node.getSource().getOutputSymbols())); + PlanNode source = context.rewrite(node.getSource(), ImmutableSet.copyOf(node.getSource().getOutputVariables())); return new TableFinishNode( node.getId(), source, node.getTarget(), - node.getRowCountSymbol(), + node.getRowCountVariable(), node.getStatisticsAggregation(), node.getStatisticsAggregationDescriptor()); } @Override - public PlanNode visitDelete(DeleteNode node, RewriteContext> context) + public PlanNode visitDelete(DeleteNode node, RewriteContext> context) { PlanNode source = context.rewrite(node.getSource(), ImmutableSet.of(node.getRowId())); - return new DeleteNode(node.getId(), source, node.getTarget(), node.getRowId(), node.getOutputSymbols()); + return new DeleteNode(node.getId(), source, node.getTarget(), node.getRowId(), node.getOutputVariables()); } @Override - public PlanNode visitUnion(UnionNode node, RewriteContext> context) + public PlanNode visitUnion(UnionNode node, RewriteContext> context) { - ListMultimap rewrittenSymbolMapping = rewriteSetOperationSymbolMapping(node, context); - ImmutableList rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenSymbolMapping); - return new UnionNode(node.getId(), rewrittenSubPlans, rewrittenSymbolMapping, ImmutableList.copyOf(rewrittenSymbolMapping.keySet())); + ListMultimap rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, context); + ImmutableList rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenVariableMapping); + return new UnionNode(node.getId(), rewrittenSubPlans, rewrittenVariableMapping); } @Override - public PlanNode visitIntersect(IntersectNode node, RewriteContext> context) + public PlanNode visitIntersect(IntersectNode node, RewriteContext> context) { - ListMultimap rewrittenSymbolMapping = rewriteSetOperationSymbolMapping(node, context); - ImmutableList rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenSymbolMapping); - return new IntersectNode(node.getId(), rewrittenSubPlans, rewrittenSymbolMapping, ImmutableList.copyOf(rewrittenSymbolMapping.keySet())); + ListMultimap rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, context); + ImmutableList rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenVariableMapping); + return new IntersectNode(node.getId(), rewrittenSubPlans, rewrittenVariableMapping); } @Override - public PlanNode visitExcept(ExceptNode node, RewriteContext> context) + public PlanNode visitExcept(ExceptNode node, RewriteContext> context) { - ListMultimap rewrittenSymbolMapping = rewriteSetOperationSymbolMapping(node, context); - ImmutableList rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenSymbolMapping); - return new ExceptNode(node.getId(), rewrittenSubPlans, rewrittenSymbolMapping, ImmutableList.copyOf(rewrittenSymbolMapping.keySet())); + ListMultimap rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, context); + ImmutableList rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenVariableMapping); + return new ExceptNode(node.getId(), rewrittenSubPlans, rewrittenVariableMapping); } - private ListMultimap rewriteSetOperationSymbolMapping(SetOperationNode node, RewriteContext> context) + private ListMultimap rewriteSetOperationVariableMapping(SetOperationNode node, RewriteContext> context) { - // Find out which output symbols we need to keep - ImmutableListMultimap.Builder rewrittenSymbolMappingBuilder = ImmutableListMultimap.builder(); - for (Symbol symbol : node.getOutputSymbols()) { - if (context.get().contains(symbol)) { - rewrittenSymbolMappingBuilder.putAll(symbol, node.getSymbolMapping().get(symbol)); + // Find out which output variables we need to keep + ImmutableListMultimap.Builder rewrittenVariableMappingBuilder = ImmutableListMultimap.builder(); + for (VariableReferenceExpression variable : node.getOutputVariables()) { + if (context.get().contains(variable)) { + rewrittenVariableMappingBuilder.putAll( + variable, + node.getVariableMapping().get(variable)); } } - return rewrittenSymbolMappingBuilder.build(); + return rewrittenVariableMappingBuilder.build(); } - private ImmutableList rewriteSetOperationSubPlans(SetOperationNode node, RewriteContext> context, ListMultimap rewrittenSymbolMapping) + private ImmutableList rewriteSetOperationSubPlans(SetOperationNode node, RewriteContext> context, ListMultimap rewrittenVariableMapping) { // Find the corresponding input symbol to the remaining output symbols and prune the subplans ImmutableList.Builder rewrittenSubPlans = ImmutableList.builder(); for (int i = 0; i < node.getSources().size(); i++) { - ImmutableSet.Builder expectedInputSymbols = ImmutableSet.builder(); - for (Collection symbols : rewrittenSymbolMapping.asMap().values()) { - expectedInputSymbols.add(Iterables.get(symbols, i)); + ImmutableSet.Builder expectedInputSymbols = ImmutableSet.builder(); + for (Collection variables : rewrittenVariableMapping.asMap().values()) { + expectedInputSymbols.add(Iterables.get(variables, i)); } rewrittenSubPlans.add(context.rewrite(node.getSources().get(i), expectedInputSymbols.build())); } @@ -746,20 +755,20 @@ private ImmutableList rewriteSetOperationSubPlans(SetOperationNode nod } @Override - public PlanNode visitValues(ValuesNode node, RewriteContext> context) + public PlanNode visitValues(ValuesNode node, RewriteContext> context) { - ImmutableList.Builder rewrittenOutputSymbolsBuilder = ImmutableList.builder(); + ImmutableList.Builder rewrittenOutputVariablesBuilder = ImmutableList.builder(); ImmutableList.Builder> rowBuildersBuilder = ImmutableList.builder(); // Initialize builder for each row for (int i = 0; i < node.getRows().size(); i++) { rowBuildersBuilder.add(ImmutableList.builder()); } ImmutableList> rowBuilders = rowBuildersBuilder.build(); - for (int i = 0; i < node.getOutputSymbols().size(); i++) { - Symbol outputSymbol = node.getOutputSymbols().get(i); + for (int i = 0; i < node.getOutputVariables().size(); i++) { + VariableReferenceExpression outputVariable = node.getOutputVariables().get(i); // If output symbol is used - if (context.get().contains(outputSymbol)) { - rewrittenOutputSymbolsBuilder.add(outputSymbol); + if (context.get().contains(outputVariable)) { + rewrittenOutputVariablesBuilder.add(outputVariable); // Add the value of the output symbol for each row for (int j = 0; j < node.getRows().size(); j++) { rowBuilders.get(j).add(node.getRows().get(j).get(i)); @@ -769,80 +778,80 @@ public PlanNode visitValues(ValuesNode node, RewriteContext> context List> rewrittenRows = rowBuilders.stream() .map(ImmutableList.Builder::build) .collect(toImmutableList()); - return new ValuesNode(node.getId(), rewrittenOutputSymbolsBuilder.build(), rewrittenRows); + return new ValuesNode(node.getId(), rewrittenOutputVariablesBuilder.build(), rewrittenRows); } @Override - public PlanNode visitApply(ApplyNode node, RewriteContext> context) + public PlanNode visitApply(ApplyNode node, RewriteContext> context) { // remove unused apply nodes - if (intersection(node.getSubqueryAssignments().getSymbols(), context.get()).isEmpty()) { + if (intersection(node.getSubqueryAssignments().getVariables(), context.get()).isEmpty()) { return context.rewrite(node.getInput(), context.get()); } // extract symbols required subquery plan - ImmutableSet.Builder subqueryAssignmentsSymbolsBuilder = ImmutableSet.builder(); + ImmutableSet.Builder subqueryAssignmentsVariablesBuilder = ImmutableSet.builder(); Assignments.Builder subqueryAssignments = Assignments.builder(); - for (Map.Entry entry : node.getSubqueryAssignments().getMap().entrySet()) { - Symbol output = entry.getKey(); + for (Map.Entry entry : node.getSubqueryAssignments().getMap().entrySet()) { + VariableReferenceExpression output = entry.getKey(); Expression expression = entry.getValue(); if (context.get().contains(output)) { - subqueryAssignmentsSymbolsBuilder.addAll(SymbolsExtractor.extractUnique(expression)); + subqueryAssignmentsVariablesBuilder.addAll(SymbolsExtractor.extractUniqueVariable(expression, symbolAllocator.getTypes())); subqueryAssignments.put(output, expression); } } - Set subqueryAssignmentsSymbols = subqueryAssignmentsSymbolsBuilder.build(); - PlanNode subquery = context.rewrite(node.getSubquery(), subqueryAssignmentsSymbols); + Set subqueryAssignmentsVariables = subqueryAssignmentsVariablesBuilder.build(); + PlanNode subquery = context.rewrite(node.getSubquery(), subqueryAssignmentsVariables); // prune not used correlation symbols - Set subquerySymbols = SymbolsExtractor.extractUnique(subquery); - List newCorrelation = node.getCorrelation().stream() + Set subquerySymbols = SymbolsExtractor.extractUniqueVariable(subquery, symbolAllocator.getTypes()); + List newCorrelation = node.getCorrelation().stream() .filter(subquerySymbols::contains) .collect(toImmutableList()); - Set inputContext = ImmutableSet.builder() + Set inputContext = ImmutableSet.builder() .addAll(context.get()) .addAll(newCorrelation) - .addAll(subqueryAssignmentsSymbols) // need to include those: e.g: "expr" from "expr IN (SELECT 1)" + .addAll(subqueryAssignmentsVariables) // need to include those: e.g: "expr" from "expr IN (SELECT 1)" .build(); PlanNode input = context.rewrite(node.getInput(), inputContext); return new ApplyNode(node.getId(), input, subquery, subqueryAssignments.build(), newCorrelation, node.getOriginSubqueryError()); } @Override - public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext> context) + public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext> context) { - if (!context.get().contains(node.getIdColumn())) { + if (!context.get().contains(node.getIdVariable())) { return context.rewrite(node.getSource(), context.get()); } return context.defaultRewrite(node, context.get()); } @Override - public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext> context) + public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext> context) { PlanNode subquery = context.rewrite(node.getSubquery(), context.get()); // remove unused lateral nodes - if (intersection(ImmutableSet.copyOf(subquery.getOutputSymbols()), context.get()).isEmpty() && isScalar(subquery)) { + if (intersection(ImmutableSet.copyOf(subquery.getOutputVariables()), context.get()).isEmpty() && isScalar(subquery)) { return context.rewrite(node.getInput(), context.get()); } // prune not used correlation symbols - Set subquerySymbols = SymbolsExtractor.extractUnique(subquery); - List newCorrelation = node.getCorrelation().stream() - .filter(subquerySymbols::contains) + Set subqueryVariables = SymbolsExtractor.extractUniqueVariable(subquery, symbolAllocator.getTypes()); + List newCorrelation = node.getCorrelation().stream() + .filter(subqueryVariables::contains) .collect(toImmutableList()); - Set inputContext = ImmutableSet.builder() + Set inputContext = ImmutableSet.builder() .addAll(context.get()) .addAll(newCorrelation) .build(); PlanNode input = context.rewrite(node.getInput(), inputContext); // remove unused lateral nodes - if (intersection(ImmutableSet.copyOf(input.getOutputSymbols()), inputContext).isEmpty() && isScalar(input)) { + if (intersection(ImmutableSet.copyOf(input.getOutputVariables()), inputContext).isEmpty() && isScalar(input)) { return subquery; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ReplicateSemiJoinInDelete.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ReplicateSemiJoinInDelete.java index 79b4a3f7eedbd..0200d866a8c33 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ReplicateSemiJoinInDelete.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ReplicateSemiJoinInDelete.java @@ -51,11 +51,11 @@ public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext context) node.getId(), sourceRewritten, filteringSourceRewritten, - node.getSourceJoinSymbol(), - node.getFilteringSourceJoinSymbol(), + node.getSourceJoinVariable(), + node.getFilteringSourceJoinVariable(), node.getSemiJoinOutput(), - node.getSourceHashSymbol(), - node.getFilteringSourceHashSymbol(), + node.getSourceHashVariable(), + node.getFilteringSourceHashVariable(), node.getDistributionType()); if (isDeleteQuery) { @@ -77,7 +77,7 @@ public PlanNode visitDelete(DeleteNode node, RewriteContext context) rewrittenSource, node.getTarget(), node.getRowId(), - node.getOutputSymbols()); + node.getOutputVariables()); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java index b3ea9c2cd9416..7d7073eb5f6cd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java @@ -15,10 +15,9 @@ import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.optimizations.PlanNodeDecorrelator.DecorrelatedNode; @@ -34,6 +33,7 @@ import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -43,6 +43,7 @@ import java.util.Optional; import java.util.Set; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -65,23 +66,23 @@ public ScalarAggregationToJoinRewriter(FunctionManager functionManager, SymbolAl this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.lookup = requireNonNull(lookup, "lookup is null"); - this.planNodeDecorrelator = new PlanNodeDecorrelator(idAllocator, lookup); + this.planNodeDecorrelator = new PlanNodeDecorrelator(idAllocator, symbolAllocator, lookup); } public PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, AggregationNode aggregation) { - List correlation = lateralJoinNode.getCorrelation(); + List correlation = lateralJoinNode.getCorrelation(); Optional source = planNodeDecorrelator.decorrelateFilters(lookup.resolve(aggregation.getSource()), correlation); if (!source.isPresent()) { return lateralJoinNode; } - Symbol nonNull = symbolAllocator.newSymbol("non_null", BooleanType.BOOLEAN); + VariableReferenceExpression nonNull = symbolAllocator.newVariable("non_null", BooleanType.BOOLEAN); Assignments scalarAggregationSourceAssignments = Assignments.builder() - .putIdentities(source.get().getNode().getOutputSymbols()) + .putIdentities(source.get().getNode().getOutputVariables()) .put(nonNull, TRUE_LITERAL) .build(); - ProjectNode scalarAggregationSourceWithNonNullableSymbol = new ProjectNode( + ProjectNode scalarAggregationSourceWithNonNullableVariable = new ProjectNode( idAllocator.getNextId(), source.get().getNode(), scalarAggregationSourceAssignments); @@ -89,7 +90,7 @@ public PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, Aggreg return rewriteScalarAggregation( lateralJoinNode, aggregation, - scalarAggregationSourceWithNonNullableSymbol, + scalarAggregationSourceWithNonNullableVariable, source.get().getCorrelatedPredicates(), nonNull); } @@ -99,12 +100,12 @@ private PlanNode rewriteScalarAggregation( AggregationNode scalarAggregation, PlanNode scalarAggregationSource, Optional joinExpression, - Symbol nonNull) + VariableReferenceExpression nonNull) { AssignUniqueId inputWithUniqueColumns = new AssignUniqueId( idAllocator.getNextId(), lateralJoinNode.getInput(), - symbolAllocator.newSymbol("unique", BigintType.BIGINT)); + symbolAllocator.newVariable("unique", BIGINT)); JoinNode leftOuterJoin = new JoinNode( idAllocator.getNextId(), @@ -112,9 +113,9 @@ private PlanNode rewriteScalarAggregation( inputWithUniqueColumns, scalarAggregationSource, ImmutableList.of(), - ImmutableList.builder() - .addAll(inputWithUniqueColumns.getOutputSymbols()) - .addAll(scalarAggregationSource.getOutputSymbols()) + ImmutableList.builder() + .addAll(inputWithUniqueColumns.getOutputVariables()) + .addAll(scalarAggregationSource.getOutputVariables()) .build(), joinExpression.map(OriginalExpressionUtils::castToRowExpression), Optional.empty(), @@ -135,11 +136,11 @@ private PlanNode rewriteScalarAggregation( .recurseOnlyWhen(EnforceSingleRowNode.class::isInstance) .findFirst(); - List aggregationOutputSymbols = getTruncatedAggregationSymbols(lateralJoinNode, aggregationNode.get()); + List aggregationOutputVariables = getTruncatedAggregationVariables(lateralJoinNode, aggregationNode.get()); if (subqueryProjection.isPresent()) { Assignments assignments = Assignments.builder() - .putIdentities(aggregationOutputSymbols) + .putIdentities(aggregationOutputVariables) .putAll(subqueryProjection.get().getAssignments()) .build(); @@ -152,38 +153,38 @@ private PlanNode rewriteScalarAggregation( return new ProjectNode( idAllocator.getNextId(), aggregationNode.get(), - Assignments.identity(aggregationOutputSymbols)); + Assignments.identity(aggregationOutputVariables)); } } - private static List getTruncatedAggregationSymbols(LateralJoinNode lateralJoinNode, AggregationNode aggregationNode) + private List getTruncatedAggregationVariables(LateralJoinNode lateralJoinNode, AggregationNode aggregationNode) { - Set applySymbols = new HashSet<>(lateralJoinNode.getOutputSymbols()); - return aggregationNode.getOutputSymbols().stream() - .filter(applySymbols::contains) + Set applyVariables = new HashSet<>(lateralJoinNode.getOutputVariables()); + return aggregationNode.getOutputVariables().stream() + .filter(applyVariables::contains) .collect(toImmutableList()); } private Optional createAggregationNode( AggregationNode scalarAggregation, JoinNode leftOuterJoin, - Symbol nonNullableAggregationSourceSymbol) + VariableReferenceExpression nonNull) { - ImmutableMap.Builder aggregations = ImmutableMap.builder(); - for (Map.Entry entry : scalarAggregation.getAggregations().entrySet()) { - Symbol symbol = entry.getKey(); + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + for (Map.Entry entry : scalarAggregation.getAggregations().entrySet()) { + VariableReferenceExpression variable = entry.getKey(); if (functionResolution.isCountFunction(entry.getValue().getFunctionHandle())) { - Type scalarAggregationSourceType = symbolAllocator.getTypes().get(nonNullableAggregationSourceSymbol); - aggregations.put(symbol, new Aggregation( + Type scalarAggregationSourceType = nonNull.getType(); + aggregations.put(variable, new Aggregation( functionResolution.countFunction(scalarAggregationSourceType), - ImmutableList.of(nonNullableAggregationSourceSymbol.toSymbolReference()), + ImmutableList.of(new SymbolReference(nonNull.getName())), Optional.empty(), Optional.empty(), false, entry.getValue().getMask())); } else { - aggregations.put(symbol, entry.getValue()); + aggregations.put(variable, entry.getValue()); } } @@ -191,10 +192,10 @@ private Optional createAggregationNode( idAllocator.getNextId(), leftOuterJoin, aggregations.build(), - singleGroupingSet(leftOuterJoin.getLeft().getOutputSymbols()), + singleGroupingSet(leftOuterJoin.getLeft().getOutputVariables()), ImmutableList.of(), scalarAggregation.getStep(), - scalarAggregation.getHashSymbol(), + scalarAggregation.getHashVariable(), Optional.empty())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SetFlatteningOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SetFlatteningOptimizer.java index 21dcb95e6c7fc..75c82da75fb18 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SetFlatteningOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SetFlatteningOptimizer.java @@ -16,7 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -64,33 +64,36 @@ public PlanNode visitPlan(PlanNode node, RewriteContext context) public PlanNode visitUnion(UnionNode node, RewriteContext context) { ImmutableList.Builder flattenedSources = ImmutableList.builder(); - ImmutableListMultimap.Builder flattenedSymbolMap = ImmutableListMultimap.builder(); - flattenSetOperation(node, context, flattenedSources, flattenedSymbolMap); + ImmutableListMultimap.Builder flattenedVariableMap = ImmutableListMultimap.builder(); + flattenSetOperation(node, context, flattenedSources, flattenedVariableMap); - return new UnionNode(node.getId(), flattenedSources.build(), flattenedSymbolMap.build(), ImmutableList.copyOf(flattenedSymbolMap.build().keySet())); + return new UnionNode(node.getId(), flattenedSources.build(), flattenedVariableMap.build()); } @Override public PlanNode visitIntersect(IntersectNode node, RewriteContext context) { ImmutableList.Builder flattenedSources = ImmutableList.builder(); - ImmutableListMultimap.Builder flattenedSymbolMap = ImmutableListMultimap.builder(); - flattenSetOperation(node, context, flattenedSources, flattenedSymbolMap); + ImmutableListMultimap.Builder flattenedVariableMap = ImmutableListMultimap.builder(); + flattenSetOperation(node, context, flattenedSources, flattenedVariableMap); - return new IntersectNode(node.getId(), flattenedSources.build(), flattenedSymbolMap.build(), ImmutableList.copyOf(flattenedSymbolMap.build().keySet())); + return new IntersectNode(node.getId(), flattenedSources.build(), flattenedVariableMap.build()); } @Override public PlanNode visitExcept(ExceptNode node, RewriteContext context) { ImmutableList.Builder flattenedSources = ImmutableList.builder(); - ImmutableListMultimap.Builder flattenedSymbolMap = ImmutableListMultimap.builder(); - flattenSetOperation(node, context, flattenedSources, flattenedSymbolMap); + ImmutableListMultimap.Builder flattenedVariableMap = ImmutableListMultimap.builder(); + flattenSetOperation(node, context, flattenedSources, flattenedVariableMap); - return new ExceptNode(node.getId(), flattenedSources.build(), flattenedSymbolMap.build(), ImmutableList.copyOf(flattenedSymbolMap.build().keySet())); + return new ExceptNode(node.getId(), flattenedSources.build(), flattenedVariableMap.build()); } - private static void flattenSetOperation(SetOperationNode node, RewriteContext context, ImmutableList.Builder flattenedSources, ImmutableListMultimap.Builder flattenedSymbolMap) + private static void flattenSetOperation( + SetOperationNode node, RewriteContext context, + ImmutableList.Builder flattenedSources, + ImmutableListMultimap.Builder flattenedSymbolMap) { for (int i = 0; i < node.getSources().size(); i++) { PlanNode subplan = node.getSources().get(i); @@ -102,14 +105,14 @@ private static void flattenSetOperation(SetOperationNode node, RewriteContext> entry : node.getSymbolMapping().asMap().entrySet()) { - Symbol inputSymbol = Iterables.get(entry.getValue(), i); - flattenedSymbolMap.putAll(entry.getKey(), rewrittenSetOperation.getSymbolMapping().get(inputSymbol)); + for (Map.Entry> entry : node.getVariableMapping().asMap().entrySet()) { + VariableReferenceExpression inputVariable = Iterables.get(entry.getValue(), i); + flattenedSymbolMap.putAll(entry.getKey(), rewrittenSetOperation.getVariableMapping().get(inputVariable)); } } else { flattenedSources.add(rewrittenSource); - for (Map.Entry> entry : node.getSymbolMapping().asMap().entrySet()) { + for (Map.Entry> entry : node.getVariableMapping().asMap().entrySet()) { flattenedSymbolMap.put(entry.getKey(), Iterables.get(entry.getValue(), i)); } } @@ -135,8 +138,8 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext c node.getGroupingSets(), ImmutableList.of(), node.getStep(), - node.getHashSymbol(), - node.getGroupIdSymbol()); + node.getHashVariable(), + node.getGroupIdVariable()); } private static boolean isDistinctOperator(AggregationNode node) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java index dcc1c555e5f64..adf5a5a2f3577 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution; import com.google.common.collect.ImmutableList; @@ -42,11 +42,11 @@ class StreamPreferredProperties private final Optional distribution; private final boolean exactColumnOrder; - private final Optional> partitioningColumns; // if missing => any partitioning scheme is acceptable + private final Optional> partitioningColumns; // if missing => any partitioning scheme is acceptable private final boolean orderSensitive; - private StreamPreferredProperties(Optional distribution, Optional> partitioningColumns, boolean orderSensitive) + private StreamPreferredProperties(Optional distribution, Optional> partitioningColumns, boolean orderSensitive) { this(distribution, false, partitioningColumns, orderSensitive); } @@ -54,7 +54,7 @@ private StreamPreferredProperties(Optional distribution, Opt private StreamPreferredProperties( Optional distribution, boolean exactColumnOrder, - Optional> partitioningColumns, + Optional> partitioningColumns, boolean orderSensitive) { this.distribution = requireNonNull(distribution, "distribution is null"); @@ -105,14 +105,14 @@ public StreamPreferredProperties withFixedParallelism() return fixedParallelism(); } - public static StreamPreferredProperties exactlyPartitionedOn(Collection partitionSymbols) + public static StreamPreferredProperties exactlyPartitionedOn(Collection partitionVariables) { - if (partitionSymbols.isEmpty()) { + if (partitionVariables.isEmpty()) { return singleStream(); } // this must be the exact partitioning symbols, in the exact order - return new StreamPreferredProperties(Optional.of(FIXED), true, Optional.of(ImmutableList.copyOf(partitionSymbols)), false); + return new StreamPreferredProperties(Optional.of(FIXED), true, Optional.of(ImmutableList.copyOf(partitionVariables)), false); } public StreamPreferredProperties withoutPreference() @@ -120,13 +120,13 @@ public StreamPreferredProperties withoutPreference() return new StreamPreferredProperties(Optional.empty(), Optional.empty(), orderSensitive); } - public StreamPreferredProperties withPartitioning(Collection partitionSymbols) + public StreamPreferredProperties withPartitioning(Collection partitionVariables) { - if (partitionSymbols.isEmpty()) { + if (partitionVariables.isEmpty()) { return singleStream(); } - Iterable desiredPartitioning = partitionSymbols; + Iterable desiredPartitioning = partitionVariables; if (partitioningColumns.isPresent()) { if (exactColumnOrder) { if (partitioningColumns.get().equals(desiredPartitioning)) { @@ -135,7 +135,7 @@ public StreamPreferredProperties withPartitioning(Collection partitionSy } else { // If there are common columns between our requirements and the desired partitionSymbols, both can be satisfied in one shot - Set common = Sets.intersection(ImmutableSet.copyOf(desiredPartitioning), ImmutableSet.copyOf(partitioningColumns.get())); + Set common = Sets.intersection(ImmutableSet.copyOf(desiredPartitioning), ImmutableSet.copyOf(partitioningColumns.get())); // If we find common partitioning columns, use them, else use child's partitioning columns if (!common.isEmpty()) { @@ -206,7 +206,7 @@ public boolean isParallelPreferred() return distribution.isPresent() && distribution.get() != SINGLE; } - public Optional> getPartitioningColumns() + public Optional> getPartitioningColumns() { return partitioningColumns; } @@ -216,19 +216,19 @@ public boolean isOrderSensitive() return orderSensitive; } - public StreamPreferredProperties translate(Function> translator) + public StreamPreferredProperties translate(Function> translator) { return new StreamPreferredProperties( distribution, - partitioningColumns.flatMap(partitioning -> translateSymbols(partitioning, translator)), + partitioningColumns.flatMap(partitioning -> translateVariables(partitioning, translator)), orderSensitive); } - private static Optional> translateSymbols(Iterable partitioning, Function> translator) + private static Optional> translateVariables(Iterable partitioning, Function> translator) { - ImmutableList.Builder newPartitioningColumns = ImmutableList.builder(); - for (Symbol partitioningColumn : partitioning) { - Optional translated = translator.apply(partitioningColumn); + ImmutableList.Builder newPartitioningColumns = ImmutableList.builder(); + for (VariableReferenceExpression partitioningColumn : partitioning) { + Optional translated = translator.apply(partitioningColumn); if (!translated.isPresent()) { return Optional.empty(); } @@ -252,22 +252,22 @@ public StreamPreferredProperties withOrderSensitivity() return new StreamPreferredProperties(distribution, false, Optional.empty(), true); } - public StreamPreferredProperties constrainTo(Iterable symbols) + public StreamPreferredProperties constrainTo(Iterable variables) { if (!partitioningColumns.isPresent()) { return this; } - ImmutableSet availableSymbols = ImmutableSet.copyOf(symbols); + ImmutableSet availableVariables = ImmutableSet.copyOf(variables); if (exactColumnOrder) { - if (availableSymbols.containsAll(partitioningColumns.get())) { + if (availableVariables.containsAll(partitioningColumns.get())) { return this; } return any(); } - List common = partitioningColumns.get().stream() - .filter(availableSymbols::contains) + List common = partitioningColumns.get().stream() + .filter(availableVariables::contains) .collect(toImmutableList()); if (common.isEmpty()) { return any(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java index b2c4fde033dbe..194382e1c7193 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java @@ -18,9 +18,9 @@ import com.facebook.presto.metadata.TableLayout; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.LocalProperty; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ApplyNode; @@ -57,8 +57,6 @@ import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.planner.plan.WindowNode; -import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableBiMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -77,8 +75,9 @@ import java.util.function.Function; import java.util.stream.Collectors; -import static com.facebook.presto.spi.predicate.TupleDomain.extractFixedValues; +import static com.facebook.presto.spi.predicate.TupleDomain.extractFixedValuesToConstantExpressions; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; +import static com.facebook.presto.sql.planner.optimizations.AddExchanges.computeIdentityTranslations; import static com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution.FIXED; import static com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution.MULTIPLE; import static com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution.SINGLE; @@ -129,17 +128,17 @@ public static StreamProperties deriveProperties(PlanNode node, List - verify(node.getOutputSymbols().containsAll(columns), "Stream-level partitioning properties contain columns not present in node's output")); + verify(node.getOutputVariables().containsAll(columns), "Stream-level partitioning properties contain columns not present in node's output")); - Set localPropertyColumns = result.getLocalProperties().stream() + Set localPropertyColumns = result.getLocalProperties().stream() .flatMap(property -> property.getColumns().stream()) .collect(Collectors.toSet()); - verify(node.getOutputSymbols().containsAll(localPropertyColumns), "Stream-level local properties contain columns not present in node's output"); + verify(node.getOutputVariables().containsAll(localPropertyColumns), "Stream-level local properties contain columns not present in node's output"); return result; } @@ -149,11 +148,13 @@ private static class Visitor { private final Metadata metadata; private final Session session; + private final TypeProvider types; - private Visitor(Metadata metadata, Session session) + private Visitor(Metadata metadata, Session session, TypeProvider types) { this.metadata = metadata; this.session = session; + this.types = types; } @Override @@ -170,16 +171,17 @@ protected StreamProperties visitPlan(PlanNode node, List input public StreamProperties visitJoin(JoinNode node, List inputProperties) { StreamProperties leftProperties = inputProperties.get(0); + List outputs = node.getOutputVariables(); boolean unordered = PropertyDerivations.spillPossible(session, node.getType()); switch (node.getType()) { case INNER: return leftProperties - .translate(column -> PropertyDerivations.filterOrRewrite(node.getOutputSymbols(), node.getCriteria(), column)) + .translate(column -> PropertyDerivations.filterOrRewrite(outputs, node.getCriteria(), column)) .unordered(unordered); case LEFT: return leftProperties - .translate(column -> PropertyDerivations.filterIfMissing(node.getOutputSymbols(), column)) + .translate(column -> PropertyDerivations.filterIfMissing(outputs, column)) .unordered(unordered); case RIGHT: // since this is a right join, none of the matched output rows will contain nulls @@ -210,7 +212,7 @@ public StreamProperties visitSpatialJoin(SpatialJoinNode node, List PropertyDerivations.filterIfMissing(node.getOutputSymbols(), column)); + return leftProperties.translate(column -> PropertyDerivations.filterIfMissing(node.getOutputVariables(), column)); default: throw new IllegalArgumentException("Unsupported spatial join type: " + node.getType()); } @@ -248,17 +250,17 @@ public StreamProperties visitValues(ValuesNode node, List cont public StreamProperties visitTableScan(TableScanNode node, List inputProperties) { TableLayout layout = metadata.getLayout(session, node.getTable()); - Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); + Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); // Globally constant assignments Set constants = new HashSet<>(); - extractFixedValues(node.getCurrentConstraint()).orElse(ImmutableMap.of()) + extractFixedValuesToConstantExpressions(node.getCurrentConstraint()).orElse(ImmutableMap.of()) .entrySet().stream() .filter(entry -> !entry.getValue().isNull()) // TODO consider allowing nulls .forEach(entry -> constants.add(entry.getKey())); - Optional> streamPartitionSymbols = layout.getStreamPartitioningColumns() - .flatMap(columns -> getNonConstantSymbols(columns, assignments, constants)); + Optional> streamPartitionSymbols = layout.getStreamPartitioningColumns() + .flatMap(columns -> getNonConstantVariables(columns, assignments, constants)); // if we are partitioned on empty set, we must say multiple of unknown partitioning, because // the connector does not guarantee a single split in this case (since it might not understand @@ -269,16 +271,16 @@ public StreamProperties visitTableScan(TableScanNode node, List> getNonConstantSymbols(Set columnHandles, Map assignments, Set globalConstants) + private Optional> getNonConstantVariables(Set columnHandles, Map assignments, Set globalConstants) { // Strip off the constants from the partitioning columns (since those are not required for translation) Set constantsStrippedPartitionColumns = columnHandles.stream() .filter(column -> !globalConstants.contains(column)) .collect(toImmutableSet()); - ImmutableSet.Builder builder = ImmutableSet.builder(); + ImmutableSet.Builder builder = ImmutableSet.builder(); for (ColumnHandle column : constantsStrippedPartitionColumns) { - Symbol translated = assignments.get(column); + VariableReferenceExpression translated = assignments.get(column); if (translated == null) { return Optional.empty(); } @@ -311,7 +313,7 @@ public StreamProperties visitExchange(ExchangeNode node, List return new StreamProperties( FIXED, Optional.of(node.getPartitioningScheme().getPartitioning().getArguments().stream() - .map(ArgumentBinding::getColumn) + .map(ArgumentBinding::getVariableReference) .collect(toImmutableList())), false); case REPLICATE: return new StreamProperties(MULTIPLE, Optional.empty(), false); @@ -330,26 +332,16 @@ public StreamProperties visitProject(ProjectNode node, List in StreamProperties properties = Iterables.getOnlyElement(inputProperties); // We can describe properties in terms of inputs that are projected unmodified (i.e., identity projections) - Map identities = computeIdentityTranslations(node.getAssignments().getMap()); - return properties.translate(column -> Optional.ofNullable(identities.get(column))); - } + Map identities = computeIdentityTranslations(node.getAssignments(), types); - private static Map computeIdentityTranslations(Map assignments) - { - Map inputToOutput = new HashMap<>(); - for (Map.Entry assignment : assignments.entrySet()) { - if (assignment.getValue() instanceof SymbolReference) { - inputToOutput.put(Symbol.from(assignment.getValue()), assignment.getKey()); - } - } - return inputToOutput; + return properties.translate(column -> Optional.ofNullable(identities.get(column))); } @Override public StreamProperties visitGroupId(GroupIdNode node, List inputProperties) { - Map inputToOutputMappings = new HashMap<>(); - for (Map.Entry setMapping : node.getGroupingColumns().entrySet()) { + Map inputToOutputMappings = new HashMap<>(); + for (Map.Entry setMapping : node.getGroupingColumns().entrySet()) { if (node.getCommonGroupingColumns().contains(setMapping.getKey())) { // TODO: Add support for translating a property on a single column to multiple columns // when GroupIdNode is copying a single input grouping column into multiple output grouping columns (i.e. aliases), this is basically picking one arbitrarily @@ -359,7 +351,7 @@ public StreamProperties visitGroupId(GroupIdNode node, List in // TODO: Add support for translating a property on a single column to multiple columns // this is deliberately placed after the grouping columns, because preserving properties has a bigger perf impact - for (Symbol argument : node.getAggregationArguments()) { + for (VariableReferenceExpression argument : node.getAggregationArguments()) { inputToOutputMappings.putIfAbsent(argument, argument); } @@ -372,7 +364,7 @@ public StreamProperties visitAggregation(AggregationNode node, List node.getGroupingKeys().contains(symbol) ? Optional.of(symbol) : Optional.empty()); + return properties.translate(variable -> node.getGroupingKeys().contains(variable) ? Optional.of(variable) : Optional.empty()); } @Override @@ -413,7 +405,7 @@ public StreamProperties visitUnnest(UnnestNode node, List inpu StreamProperties properties = Iterables.getOnlyElement(inputProperties); // We can describe properties in terms of inputs that are projected unmodified (i.e., not the unnested symbols) - Set passThroughInputs = ImmutableSet.copyOf(node.getReplicateSymbols()); + Set passThroughInputs = ImmutableSet.copyOf(node.getReplicateVariables()); return properties.translate(column -> { if (passThroughInputs.contains(column)) { return Optional.of(column); @@ -463,7 +455,7 @@ public StreamProperties visitAssignUniqueId(AssignUniqueId node, List inputProperties) { return Iterables.getOnlyElement(inputProperties) - .translate(column -> PropertyDerivations.filterIfMissing(node.getOutputSymbols(), column)); + .translate(column -> PropertyDerivations.filterIfMissing(node.getOutputVariables(), column)); } @Override @@ -577,7 +569,7 @@ public enum StreamDistribution private final StreamDistribution distribution; - private final Optional> partitioningColumns; // if missing => partitioned with some unknown scheme + private final Optional> partitioningColumns; // if missing => partitioned with some unknown scheme private final boolean ordered; @@ -588,14 +580,14 @@ public enum StreamDistribution // NOTE: Partitioning on zero columns (or effectively zero columns if the columns are constant) indicates that all // the rows will be partitioned into a single stream. - private StreamProperties(StreamDistribution distribution, Optional> partitioningColumns, boolean ordered) + private StreamProperties(StreamDistribution distribution, Optional> partitioningColumns, boolean ordered) { this(distribution, partitioningColumns, ordered, null); } private StreamProperties( StreamDistribution distribution, - Optional> partitioningColumns, + Optional> partitioningColumns, boolean ordered, ActualProperties otherActualProperties) { @@ -615,7 +607,7 @@ private StreamProperties( this.otherActualProperties = otherActualProperties; } - public List> getLocalProperties() + public List> getLocalProperties() { checkState(otherActualProperties != null, "otherActualProperties not set"); return otherActualProperties.getLocalProperties(); @@ -664,12 +656,12 @@ public StreamDistribution getDistribution() return distribution; } - public boolean isExactlyPartitionedOn(Iterable columns) + public boolean isExactlyPartitionedOn(Iterable columns) { return partitioningColumns.isPresent() && columns.equals(ImmutableList.copyOf(partitioningColumns.get())); } - public boolean isPartitionedOn(Iterable columns) + public boolean isPartitionedOn(Iterable columns) { if (!partitioningColumns.isPresent()) { return false; @@ -701,14 +693,14 @@ private StreamProperties withOtherActualProperties(ActualProperties actualProper return new StreamProperties(distribution, partitioningColumns, ordered, actualProperties); } - public StreamProperties translate(Function> translator) + public StreamProperties translate(Function> translator) { return new StreamProperties( distribution, partitioningColumns.flatMap(partitioning -> { - ImmutableList.Builder newPartitioningColumns = ImmutableList.builder(); - for (Symbol partitioningColumn : partitioning) { - Optional translated = translator.apply(partitioningColumn); + ImmutableList.Builder newPartitioningColumns = ImmutableList.builder(); + for (VariableReferenceExpression partitioningColumn : partitioning) { + Optional translated = translator.apply(partitioningColumn); if (!translated.isPresent()) { return Optional.empty(); } @@ -719,7 +711,7 @@ public StreamProperties translate(Function> translator) ordered, otherActualProperties.translate(translator)); } - public Optional> getPartitioningColumns() + public Optional> getPartitioningColumns() { return partitioningColumns; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java index 5dbc405a2aa11..0eccca450f6cb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java @@ -16,9 +16,11 @@ import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.OrderingScheme; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -36,6 +38,7 @@ import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import java.util.HashSet; import java.util.List; @@ -52,20 +55,47 @@ public class SymbolMapper { - private final Map mapping; + private final Map mapping; + private final TypeProvider types; - public SymbolMapper(Map mapping) + public SymbolMapper(Map mapping) { - this.mapping = ImmutableMap.copyOf(requireNonNull(mapping, "mapping is null")); + requireNonNull(mapping, "mapping is null"); + this.mapping = mapping.entrySet().stream().collect(toImmutableMap(entry -> entry.getKey().getName(), entry -> entry.getValue().getName())); + ImmutableSet.Builder variables = ImmutableSet.builder(); + mapping.entrySet().forEach(entry -> { + variables.add(entry.getKey()); + variables.add(entry.getValue()); + }); + this.types = TypeProvider.fromVariables(variables.build()); + } + + public SymbolMapper(Map mapping, TypeProvider types) + { + requireNonNull(mapping, "mapping is null"); + this.mapping = ImmutableMap.copyOf(mapping); + this.types = requireNonNull(types, "types is null"); } public Symbol map(Symbol symbol) { - Symbol canonical = symbol; + String canonical = symbol.getName(); + while (mapping.containsKey(canonical) && !mapping.get(canonical).equals(canonical)) { + canonical = mapping.get(canonical); + } + return new Symbol(canonical); + } + + public VariableReferenceExpression map(VariableReferenceExpression variable) + { + String canonical = variable.getName(); while (mapping.containsKey(canonical) && !mapping.get(canonical).equals(canonical)) { canonical = mapping.get(canonical); } - return canonical; + if (canonical.equals(variable.getName())) { + return variable; + } + return new VariableReferenceExpression(canonical, types.get(new Symbol(canonical))); } public Expression map(Expression value) @@ -106,8 +136,8 @@ public AggregationNode map(AggregationNode node, PlanNode source, PlanNodeIdAllo private AggregationNode map(AggregationNode node, PlanNode source, PlanNodeId newNodeId) { - ImmutableMap.Builder aggregations = ImmutableMap.builder(); - for (Entry entry : node.getAggregations().entrySet()) { + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + for (Entry entry : node.getAggregations().entrySet()) { aggregations.put(map(entry.getKey()), map(entry.getValue())); } @@ -116,13 +146,13 @@ private AggregationNode map(AggregationNode node, PlanNode source, PlanNodeId ne source, aggregations.build(), groupingSets( - mapAndDistinct(node.getGroupingKeys()), + mapAndDistinctVariable(node.getGroupingKeys()), node.getGroupingSetCount(), node.getGlobalGroupingSets()), ImmutableList.of(), node.getStep(), - node.getHashSymbol().map(this::map), - node.getGroupIdSymbol().map(this::map)); + node.getHashVariable().map(this::map), + node.getGroupIdVariable().map(this::map)); } private Aggregation map(Aggregation aggregation) @@ -138,15 +168,15 @@ private Aggregation map(Aggregation aggregation) public TopNNode map(TopNNode node, PlanNode source, PlanNodeId newNodeId) { - ImmutableList.Builder symbols = ImmutableList.builder(); - ImmutableMap.Builder orderings = ImmutableMap.builder(); - Set seenCanonicals = new HashSet<>(node.getOrderingScheme().getOrderBy().size()); - for (Symbol symbol : node.getOrderingScheme().getOrderBy()) { - Symbol canonical = map(symbol); + ImmutableList.Builder variables = ImmutableList.builder(); + ImmutableMap.Builder orderings = ImmutableMap.builder(); + Set seenCanonicals = new HashSet<>(node.getOrderingScheme().getOrderBy().size()); + for (VariableReferenceExpression variable : node.getOrderingScheme().getOrderBy()) { + VariableReferenceExpression canonical = map(variable); if (seenCanonicals.add(canonical)) { seenCanonicals.add(canonical); - symbols.add(canonical); - orderings.put(canonical, node.getOrderingScheme().getOrdering(symbol)); + variables.add(canonical); + orderings.put(canonical, node.getOrderingScheme().getOrdering(variable)); } } @@ -154,7 +184,7 @@ public TopNNode map(TopNNode node, PlanNode source, PlanNodeId newNodeId) newNodeId, source, node.getCount(), - new OrderingScheme(symbols.build(), orderings.build()), + new OrderingScheme(variables.build(), orderings.build()), node.getStep()); } @@ -166,7 +196,7 @@ public TableWriterNode map(TableWriterNode node, PlanNode source) public TableWriterNode map(TableWriterNode node, PlanNode source, PlanNodeId newNodeId) { // Intentionally does not use canonicalizeAndDistinct as that would remove columns - ImmutableList columns = node.getColumns().stream() + ImmutableList columns = node.getColumns().stream() .map(this::map) .collect(toImmutableList()); @@ -174,9 +204,9 @@ public TableWriterNode map(TableWriterNode node, PlanNode source, PlanNodeId new newNodeId, source, node.getTarget(), - map(node.getRowCountSymbol()), - map(node.getFragmentSymbol()), - map(node.getTableCommitContextSymbol()), + map(node.getRowCountVariable()), + map(node.getFragmentVariable()), + map(node.getTableCommitContextVariable()), columns, node.getColumnNames(), node.getPartitioningScheme().map(partitioningScheme -> canonicalize(partitioningScheme, source)), @@ -190,7 +220,7 @@ public StatisticsWriterNode map(StatisticsWriterNode node, PlanNode source) node.getId(), source, node.getTarget(), - node.getRowCountSymbol(), + node.getRowCountVariable(), node.isRowCountEnabled(), node.getDescriptor().map(this::map)); } @@ -201,7 +231,7 @@ public TableFinishNode map(TableFinishNode node, PlanNode source) node.getId(), source, node.getTarget(), - map(node.getRowCountSymbol()), + map(node.getRowCountVariable()), node.getStatisticsAggregation().map(this::map), node.getStatisticsAggregationDescriptor().map(descriptor -> descriptor.map(this::map))); } @@ -210,7 +240,7 @@ private PartitioningScheme canonicalize(PartitioningScheme scheme, PlanNode sour { return new PartitioningScheme( scheme.getPartitioning().translate(this::map), - mapAndDistinct(source.getOutputSymbols()), + mapAndDistinctVariable(source.getOutputVariables()), scheme.getHashColumn().map(this::map), scheme.isReplicateNullsAndAny(), scheme.getBucketToPartition()); @@ -218,24 +248,17 @@ private PartitioningScheme canonicalize(PartitioningScheme scheme, PlanNode sour private StatisticAggregations map(StatisticAggregations statisticAggregations) { - Map aggregations = statisticAggregations.getAggregations().entrySet().stream() + Map aggregations = statisticAggregations.getAggregations().entrySet().stream() .collect(toImmutableMap(entry -> map(entry.getKey()), entry -> map(entry.getValue()))); - return new StatisticAggregations(aggregations, mapAndDistinct(statisticAggregations.getGroupingSymbols())); + return new StatisticAggregations(aggregations, mapAndDistinctVariable(statisticAggregations.getGroupingVariables())); } - private StatisticAggregationsDescriptor map(StatisticAggregationsDescriptor descriptor) + private StatisticAggregationsDescriptor map(StatisticAggregationsDescriptor descriptor) { return descriptor.map(this::map); } - private List map(List outputs) - { - return outputs.stream() - .map(this::map) - .collect(toImmutableList()); - } - - private List mapAndDistinct(List outputs) + private List mapAndDistinctSymbol(List outputs) { Set added = new HashSet<>(); ImmutableList.Builder builder = ImmutableList.builder(); @@ -248,6 +271,19 @@ private List mapAndDistinct(List outputs) return builder.build(); } + private List mapAndDistinctVariable(List outputs) + { + Set added = new HashSet<>(); + ImmutableList.Builder builder = ImmutableList.builder(); + for (VariableReferenceExpression variable : outputs) { + VariableReferenceExpression canonical = map(variable); + if (added.add(canonical)) { + builder.add(canonical); + } + } + return builder.build(); + } + public static SymbolMapper.Builder builder() { return new Builder(); @@ -255,16 +291,16 @@ public static SymbolMapper.Builder builder() public static class Builder { - private final ImmutableMap.Builder mappings = ImmutableMap.builder(); + private final ImmutableMap.Builder mappingsBuilder = ImmutableMap.builder(); public SymbolMapper build() { - return new SymbolMapper(mappings.build()); + return new SymbolMapper(mappingsBuilder.build()); } - public void put(Symbol from, Symbol to) + public void put(VariableReferenceExpression from, VariableReferenceExpression to) { - mappings.put(from, to); + mappingsBuilder.put(from, to); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java index 506005d9d1810..62b47a3e6f96e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java @@ -18,6 +18,7 @@ import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.Type; @@ -43,6 +44,7 @@ import com.facebook.presto.sql.tree.QuantifiedComparisonExpression; import com.facebook.presto.sql.tree.SearchedCaseExpression; import com.facebook.presto.sql.tree.SimpleCaseExpression; +import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.WhenClause; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -83,24 +85,20 @@ public TransformQuantifiedComparisonApplyToLateralJoin(FunctionManager functionM @Override public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) { - return rewriteWith(new Rewriter(functionResolution, session, idAllocator, types, symbolAllocator), plan, null); + return rewriteWith(new Rewriter(functionResolution, idAllocator, symbolAllocator), plan, null); } private static class Rewriter extends SimplePlanRewriter { private final StandardFunctionResolution functionResolution; - private final Session session; private final PlanNodeIdAllocator idAllocator; - private final TypeProvider types; private final SymbolAllocator symbolAllocator; - public Rewriter(StandardFunctionResolution functionResolution, Session session, PlanNodeIdAllocator idAllocator, TypeProvider types, SymbolAllocator symbolAllocator) + public Rewriter(StandardFunctionResolution functionResolution, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) { this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); - this.session = requireNonNull(session, "session is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); - this.types = requireNonNull(types, "types is null"); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); } @@ -125,16 +123,16 @@ private PlanNode rewriteQuantifiedApplyNode(ApplyNode node, QuantifiedComparison { PlanNode subqueryPlan = context.rewrite(node.getSubquery()); - Symbol outputColumn = getOnlyElement(subqueryPlan.getOutputSymbols()); - Type outputColumnType = types.get(outputColumn); + VariableReferenceExpression outputColumn = getOnlyElement(subqueryPlan.getOutputVariables()); + Type outputColumnType = outputColumn.getType(); checkState(outputColumnType.isOrderable(), "Subquery result type must be orderable"); - Symbol minValue = symbolAllocator.newSymbol("min", outputColumnType); - Symbol maxValue = symbolAllocator.newSymbol("max", outputColumnType); - Symbol countAllValue = symbolAllocator.newSymbol("count_all", BigintType.BIGINT); - Symbol countNonNullValue = symbolAllocator.newSymbol("count_non_null", BigintType.BIGINT); + VariableReferenceExpression minValue = symbolAllocator.newVariable("min", outputColumnType); + VariableReferenceExpression maxValue = symbolAllocator.newVariable("max", outputColumnType); + VariableReferenceExpression countAllValue = symbolAllocator.newVariable("count_all", BigintType.BIGINT); + VariableReferenceExpression countNonNullValue = symbolAllocator.newVariable("count_non_null", BigintType.BIGINT); - List outputColumnReferences = ImmutableList.of(outputColumn.toSymbolReference()); + List outputColumnReferences = ImmutableList.of(new SymbolReference(outputColumn.getName())); subqueryPlan = new AggregationNode( idAllocator.getNextId(), @@ -182,11 +180,16 @@ countNonNullValue, new Aggregation( LateralJoinNode.Type.INNER, node.getOriginSubqueryError()); - Expression valueComparedToSubquery = rewriteUsingBounds(quantifiedComparison, minValue, maxValue, countAllValue, countNonNullValue); + Expression valueComparedToSubquery = rewriteUsingBounds( + quantifiedComparison, + new Symbol(minValue.getName()), + new Symbol(maxValue.getName()), + new Symbol(countAllValue.getName()), + new Symbol(countNonNullValue.getName())); - Symbol quantifiedComparisonSymbol = getOnlyElement(node.getSubqueryAssignments().getSymbols()); + VariableReferenceExpression quantifiedComparisonVariable = getOnlyElement(node.getSubqueryAssignments().getVariables()); - return projectExpressions(lateralJoinNode, Assignments.of(quantifiedComparisonSymbol, valueComparedToSubquery)); + return projectExpressions(lateralJoinNode, Assignments.of(quantifiedComparisonVariable, valueComparedToSubquery)); } public Expression rewriteUsingBounds(QuantifiedComparisonExpression quantifiedComparison, Symbol minValue, Symbol maxValue, Symbol countAllValue, Symbol countNonNullValue) @@ -262,7 +265,7 @@ private static boolean shouldCompareValueWithLowerBound(QuantifiedComparisonExpr private ProjectNode projectExpressions(PlanNode input, Assignments subqueryAssignments) { Assignments assignments = Assignments.builder() - .putIdentities(input.getOutputSymbols()) + .putIdentities(input.getOutputVariables()) .putAll(subqueryAssignments) .build(); return new ProjectNode( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TranslateExpressions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TranslateExpressions.java index f4afddd2300d5..9a958bbf88d17 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TranslateExpressions.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TranslateExpressions.java @@ -19,9 +19,9 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -106,10 +106,10 @@ public Result apply(SpatialJoinNode spatialJoinNode, Captures captures, Context spatialJoinNode.getType(), spatialJoinNode.getLeft(), spatialJoinNode.getRight(), - spatialJoinNode.getOutputSymbols(), + spatialJoinNode.getOutputVariables(), rewritten, - spatialJoinNode.getLeftPartitionSymbol(), - spatialJoinNode.getRightPartitionSymbol(), + spatialJoinNode.getLeftPartitionVariable(), + spatialJoinNode.getRightPartitionVariable(), spatialJoinNode.getKdbTree())); } } @@ -148,10 +148,10 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) joinNode.getLeft(), joinNode.getRight(), joinNode.getCriteria(), - joinNode.getOutputSymbols(), + joinNode.getOutputVariables(), Optional.of(rewritten), - joinNode.getLeftHashSymbol(), - joinNode.getRightHashSymbol(), + joinNode.getLeftHashVariable(), + joinNode.getRightHashVariable(), joinNode.getDistributionType())); } } @@ -170,8 +170,8 @@ public Result apply(WindowNode windowNode, Captures captures, Context context) { checkState(windowNode.getSource() != null); boolean anyRewritten = false; - ImmutableMap.Builder functions = ImmutableMap.builder(); - for (Entry entry : windowNode.getWindowFunctions().entrySet()) { + ImmutableMap.Builder functions = ImmutableMap.builder(); + for (Entry entry : windowNode.getWindowFunctions().entrySet()) { ImmutableList.Builder newArguments = ImmutableList.builder(); CallExpression callExpression = entry.getValue().getFunctionCall(); for (RowExpression argument : callExpression.getArguments()) { @@ -200,7 +200,7 @@ public Result apply(WindowNode windowNode, Captures captures, Context context) windowNode.getSource(), windowNode.getSpecification(), functions.build(), - windowNode.getHashSymbol(), + windowNode.getHashVariable(), windowNode.getPrePartitionedInputs(), windowNode.getPreSortedOrderPrefix())); } @@ -265,7 +265,7 @@ public Result apply(ValuesNode valuesNode, Captures captures, Context context) rows.add(newRow.build()); } if (anyRewritten) { - return Result.ofPlanNode(new ValuesNode(valuesNode.getId(), valuesNode.getOutputSymbols(), rows.build())); + return Result.ofPlanNode(new ValuesNode(valuesNode.getId(), valuesNode.getOutputVariables(), rows.build())); } return Result.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index c7cb17c67ebc2..a8c8e23822a99 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -19,6 +19,7 @@ import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.ExpressionDeterminismEvaluator; import com.facebook.presto.sql.planner.OrderingScheme; import com.facebook.presto.sql.planner.PartitioningScheme; @@ -127,7 +128,7 @@ public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, Sym private static class Rewriter extends SimplePlanRewriter { - private final Map mapping = new HashMap<>(); + private final Map mapping = new HashMap<>(); private final TypeProvider types; private Rewriter(TypeProvider types) @@ -140,7 +141,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont { PlanNode source = context.rewrite(node.getSource()); //TODO: use mapper in other methods - SymbolMapper mapper = new SymbolMapper(mapping); + SymbolMapper mapper = new SymbolMapper(mapping, types); return mapper.map(node, source); } @@ -149,45 +150,44 @@ public PlanNode visitGroupId(GroupIdNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); - Map newGroupingMappings = new HashMap<>(); - ImmutableList.Builder> newGroupingSets = ImmutableList.builder(); + Map newGroupingMappings = new HashMap<>(); + ImmutableList.Builder> newGroupingSets = ImmutableList.builder(); - for (List groupingSet : node.getGroupingSets()) { - ImmutableList.Builder newGroupingSet = ImmutableList.builder(); - for (Symbol output : groupingSet) { + for (List groupingSet : node.getGroupingSets()) { + ImmutableList.Builder newGroupingSet = ImmutableList.builder(); + for (VariableReferenceExpression output : groupingSet) { newGroupingMappings.putIfAbsent(canonicalize(output), canonicalize(node.getGroupingColumns().get(output))); newGroupingSet.add(canonicalize(output)); } newGroupingSets.add(newGroupingSet.build()); } - return new GroupIdNode(node.getId(), source, newGroupingSets.build(), newGroupingMappings, canonicalizeAndDistinct(node.getAggregationArguments()), canonicalize(node.getGroupIdSymbol())); + return new GroupIdNode(node.getId(), source, newGroupingSets.build(), newGroupingMappings, canonicalizeAndDistinctVariable(node.getAggregationArguments()), canonicalize(node.getGroupIdVariable())); } @Override public PlanNode visitExplainAnalyze(ExplainAnalyzeNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); - return new ExplainAnalyzeNode(node.getId(), source, canonicalize(node.getOutputSymbol()), node.isVerbose()); + return new ExplainAnalyzeNode(node.getId(), source, canonicalize(node.getOutputVariable()), node.isVerbose()); } @Override public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); - List symbols = canonicalizeAndDistinct(node.getDistinctSymbols()); - return new MarkDistinctNode(node.getId(), source, canonicalize(node.getMarkerSymbol()), symbols, canonicalize(node.getHashSymbol())); + return new MarkDistinctNode(node.getId(), source, canonicalize(node.getMarkerVariable()), canonicalizeAndDistinctVariable(node.getDistinctVariables()), canonicalize(node.getHashVariable())); } @Override public PlanNode visitUnnest(UnnestNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); - ImmutableMap.Builder> builder = ImmutableMap.builder(); - for (Map.Entry> entry : node.getUnnestSymbols().entrySet()) { + ImmutableMap.Builder> builder = ImmutableMap.builder(); + for (Map.Entry> entry : node.getUnnestVariables().entrySet()) { builder.put(canonicalize(entry.getKey()), entry.getValue()); } - return new UnnestNode(node.getId(), source, canonicalizeAndDistinct(node.getReplicateSymbols()), builder.build(), node.getOrdinalitySymbol()); + return new UnnestNode(node.getId(), source, canonicalizeAndDistinctVariable(node.getReplicateVariables()), builder.build(), node.getOrdinalityVariable()); } @Override @@ -195,9 +195,9 @@ public PlanNode visitWindow(WindowNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); - ImmutableMap.Builder functions = ImmutableMap.builder(); - for (Map.Entry entry : node.getWindowFunctions().entrySet()) { - Symbol symbol = entry.getKey(); + ImmutableMap.Builder functions = ImmutableMap.builder(); + for (Map.Entry entry : node.getWindowFunctions().entrySet()) { + VariableReferenceExpression symbol = entry.getKey(); // Be aware of the CallExpression handling. CallExpression callExpression = entry.getValue().getFunctionCall(); @@ -220,8 +220,8 @@ public PlanNode visitWindow(WindowNode node, RewriteContext context) source, canonicalizeAndDistinct(node.getSpecification()), functions.build(), - canonicalize(node.getHashSymbol()), - canonicalize(node.getPrePartitionedInputs()), + canonicalize(node.getHashVariable()), + canonicalizeVariables(node.getPrePartitionedInputs()), node.getPreSortedOrderPrefix()); } @@ -261,19 +261,19 @@ public PlanNode visitExchange(ExchangeNode node, RewriteContext context) mapExchangeNodeSymbols(node); - List> inputs = new ArrayList<>(); + List> inputs = new ArrayList<>(); for (int i = 0; i < node.getInputs().size(); i++) { inputs.add(new ArrayList<>()); } - Set addedOutputs = new HashSet<>(); - ImmutableList.Builder outputs = ImmutableList.builder(); - for (int symbolIndex = 0; symbolIndex < node.getOutputSymbols().size(); symbolIndex++) { - Symbol canonicalOutput = canonicalize(node.getOutputSymbols().get(symbolIndex)); + Set addedOutputs = new HashSet<>(); + ImmutableList.Builder outputs = ImmutableList.builder(); + for (int variableIndex = 0; variableIndex < node.getOutputVariables().size(); variableIndex++) { + VariableReferenceExpression canonicalOutput = canonicalize(node.getOutputVariables().get(variableIndex)); if (addedOutputs.add(canonicalOutput)) { outputs.add(canonicalOutput); for (int i = 0; i < node.getInputs().size(); i++) { - List input = node.getInputs().get(i); - inputs.get(i).add(canonicalize(input.get(symbolIndex))); + List input = node.getInputs().get(i); + inputs.get(i).add(canonicalize(input.get(variableIndex))); } } } @@ -297,14 +297,14 @@ private void mapExchangeNodeSymbols(ExchangeNode node) return; } - // Mapping from list [node.getInput(0).get(symbolIndex), node.getInput(1).get(symbolIndex), ...] to node.getOutputSymbols(symbolIndex). + // Mapping from list [node.getInput(0).get(symbolIndex), node.getInput(1).get(symbolIndex), ...] to node.getOutputVariables(symbolIndex). // All symbols are canonical. - Map, Symbol> inputsToOutputs = new HashMap<>(); + Map, VariableReferenceExpression> inputsToOutputs = new HashMap<>(); // Map each same list of input symbols [I1, I2, ..., In] to the same output symbol O - for (int symbolIndex = 0; symbolIndex < node.getOutputSymbols().size(); symbolIndex++) { - Symbol canonicalOutput = canonicalize(node.getOutputSymbols().get(symbolIndex)); - List canonicalInputs = canonicalizeExchangeNodeInputs(node, symbolIndex); - Symbol output = inputsToOutputs.get(canonicalInputs); + for (int variableIndex = 0; variableIndex < node.getOutputVariables().size(); variableIndex++) { + VariableReferenceExpression canonicalOutput = canonicalize(node.getOutputVariables().get(variableIndex)); + List canonicalInputs = canonicalizeExchangeNodeInputs(node, variableIndex); + VariableReferenceExpression output = inputsToOutputs.get(canonicalInputs); if (output == null || canonicalOutput.equals(output)) { inputsToOutputs.put(canonicalInputs, canonicalOutput); @@ -319,9 +319,9 @@ private void mapExchangeNodeOutputToInputSymbols(ExchangeNode node) { checkState(node.getInputs().size() == 1); - for (int symbolIndex = 0; symbolIndex < node.getOutputSymbols().size(); symbolIndex++) { - Symbol canonicalOutput = canonicalize(node.getOutputSymbols().get(symbolIndex)); - Symbol canonicalInput = canonicalize(node.getInputs().get(0).get(symbolIndex)); + for (int variableIndex = 0; variableIndex < node.getOutputVariables().size(); variableIndex++) { + VariableReferenceExpression canonicalOutput = canonicalize(node.getOutputVariables().get(variableIndex)); + VariableReferenceExpression canonicalInput = canonicalize(node.getInputs().get(0).get(variableIndex)); if (!canonicalOutput.equals(canonicalInput)) { map(canonicalOutput, canonicalInput); @@ -329,7 +329,7 @@ private void mapExchangeNodeOutputToInputSymbols(ExchangeNode node) } } - private List canonicalizeExchangeNodeInputs(ExchangeNode node, int symbolIndex) + private List canonicalizeExchangeNodeInputs(ExchangeNode node, int symbolIndex) { return node.getInputs().stream() .map(input -> canonicalize(input.get(symbolIndex))) @@ -342,7 +342,7 @@ public PlanNode visitRemoteSource(RemoteSourceNode node, RewriteContext co return new RemoteSourceNode( node.getId(), node.getSourceFragmentIds(), - canonicalizeAndDistinct(node.getOutputSymbols()), + canonicalizeAndDistinctVariable(node.getOutputVariables()), node.getOrderingScheme().map(this::canonicalizeAndDistinct), node.getExchangeType()); } @@ -356,7 +356,7 @@ public PlanNode visitLimit(LimitNode node, RewriteContext context) @Override public PlanNode visitDistinctLimit(DistinctLimitNode node, RewriteContext context) { - return new DistinctLimitNode(node.getId(), context.rewrite(node.getSource()), node.getLimit(), node.isPartial(), canonicalizeAndDistinct(node.getDistinctSymbols()), canonicalize(node.getHashSymbol())); + return new DistinctLimitNode(node.getId(), context.rewrite(node.getSource()), node.getLimit(), node.isPartial(), canonicalizeAndDistinctVariable(node.getDistinctVariables()), canonicalize(node.getHashVariable())); } @Override @@ -378,25 +378,25 @@ public PlanNode visitValues(ValuesNode node, RewriteContext context) }) .collect(toImmutableList())) .collect(toImmutableList()); - List canonicalizedOutputSymbols = canonicalizeAndDistinct(node.getOutputSymbols()); - checkState(node.getOutputSymbols().size() == canonicalizedOutputSymbols.size(), "Values output symbols were pruned"); + List canonicalizedOutputVariables = canonicalizeAndDistinctVariable(node.getOutputVariables()); + checkState(node.getOutputVariables().size() == canonicalizedOutputVariables.size(), "Values output symbols were pruned"); return new ValuesNode( node.getId(), - canonicalizedOutputSymbols, + canonicalizedOutputVariables, canonicalizedRows); } @Override public PlanNode visitDelete(DeleteNode node, RewriteContext context) { - return new DeleteNode(node.getId(), context.rewrite(node.getSource()), node.getTarget(), canonicalize(node.getRowId()), node.getOutputSymbols()); + return new DeleteNode(node.getId(), context.rewrite(node.getSource()), node.getTarget(), canonicalize(node.getRowId()), node.getOutputVariables()); } @Override public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); - SymbolMapper mapper = new SymbolMapper(mapping); + SymbolMapper mapper = new SymbolMapper(mapping, types); return mapper.map(node, source); } @@ -404,14 +404,14 @@ public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteCont public PlanNode visitTableFinish(TableFinishNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); - SymbolMapper mapper = new SymbolMapper(mapping); + SymbolMapper mapper = new SymbolMapper(mapping, types); return mapper.map(node, source); } @Override public PlanNode visitRowNumber(RowNumberNode node, RewriteContext context) { - return new RowNumberNode(node.getId(), context.rewrite(node.getSource()), canonicalizeAndDistinct(node.getPartitionBy()), canonicalize(node.getRowNumberSymbol()), node.getMaxRowCountPerPartition(), canonicalize(node.getHashSymbol())); + return new RowNumberNode(node.getId(), context.rewrite(node.getSource()), canonicalizeAndDistinctVariable(node.getPartitionBy()), canonicalize(node.getRowNumberVariable()), node.getMaxRowCountPerPartition(), canonicalize(node.getHashVariable())); } @Override @@ -421,10 +421,10 @@ public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext node.getId(), context.rewrite(node.getSource()), canonicalizeAndDistinct(node.getSpecification()), - canonicalize(node.getRowNumberSymbol()), + canonicalize(node.getRowNumberVariable()), node.getMaxRowCountPerPartition(), node.isPartial(), - canonicalize(node.getHashSymbol())); + canonicalize(node.getHashVariable())); } @Override @@ -447,7 +447,7 @@ public PlanNode visitOutput(OutputNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); - List canonical = Lists.transform(node.getOutputSymbols(), this::canonicalize); + List canonical = Lists.transform(node.getOutputVariables(), this::canonicalize); return new OutputNode(node.getId(), source, node.getColumnNames(), canonical); } @@ -464,7 +464,7 @@ public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext co { PlanNode source = context.rewrite(node.getSource()); - return new AssignUniqueId(node.getId(), source, node.getIdColumn()); + return new AssignUniqueId(node.getId(), source, node.getIdVariable()); } @Override @@ -472,7 +472,7 @@ public PlanNode visitApply(ApplyNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getInput()); PlanNode subquery = context.rewrite(node.getSubquery()); - List canonicalCorrelation = Lists.transform(node.getCorrelation(), this::canonicalize); + List canonicalCorrelation = Lists.transform(node.getCorrelation(), this::canonicalize); return new ApplyNode(node.getId(), source, subquery, canonicalize(node.getSubqueryAssignments()), canonicalCorrelation, node.getOriginSubqueryError()); } @@ -482,7 +482,7 @@ public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext cont { PlanNode source = context.rewrite(node.getInput()); PlanNode subquery = context.rewrite(node.getSubquery()); - List canonicalCorrelation = canonicalizeAndDistinct(node.getCorrelation()); + List canonicalCorrelation = canonicalizeAndDistinctVariable(node.getCorrelation()); return new LateralJoinNode(node.getId(), source, subquery, canonicalCorrelation, node.getType(), node.getOriginSubqueryError()); } @@ -492,7 +492,7 @@ public PlanNode visitTopN(TopNNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); - SymbolMapper mapper = new SymbolMapper(mapping); + SymbolMapper mapper = new SymbolMapper(mapping, types); return mapper.map(node, source, node.getId()); } @@ -512,13 +512,13 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) List canonicalCriteria = canonicalizeJoinCriteria(node.getCriteria()); Optional canonicalFilter = node.getFilter().map(OriginalExpressionUtils::castToExpression).map(this::canonicalize); - Optional canonicalLeftHashSymbol = canonicalize(node.getLeftHashSymbol()); - Optional canonicalRightHashSymbol = canonicalize(node.getRightHashSymbol()); + Optional canonicalLeftHashVariable = canonicalize(node.getLeftHashVariable()); + Optional canonicalRightHashVariable = canonicalize(node.getRightHashVariable()); if (node.getType().equals(INNER)) { canonicalCriteria.stream() - .filter(clause -> types.get(clause.getLeft()).equals(types.get(clause.getRight()))) - .filter(clause -> node.getOutputSymbols().contains(clause.getLeft())) + .filter(clause -> clause.getLeft().getType().equals(clause.getRight().getType())) + .filter(clause -> node.getOutputVariables().contains(clause.getLeft())) .forEach(clause -> map(clause.getRight(), clause.getLeft())); } @@ -528,10 +528,10 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) left, right, canonicalCriteria, - canonicalizeAndDistinct(node.getOutputSymbols()), + canonicalizeAndDistinctVariable(node.getOutputVariables()), canonicalFilter.map(OriginalExpressionUtils::castToRowExpression), - canonicalLeftHashSymbol, - canonicalRightHashSymbol, + canonicalLeftHashVariable, + canonicalRightHashVariable, node.getDistributionType()); } @@ -545,11 +545,11 @@ public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext context) node.getId(), source, filteringSource, - canonicalize(node.getSourceJoinSymbol()), - canonicalize(node.getFilteringSourceJoinSymbol()), + canonicalize(node.getSourceJoinVariable()), + canonicalize(node.getFilteringSourceJoinVariable()), canonicalize(node.getSemiJoinOutput()), - canonicalize(node.getSourceHashSymbol()), - canonicalize(node.getFilteringSourceHashSymbol()), + canonicalize(node.getSourceHashVariable()), + canonicalize(node.getFilteringSourceHashVariable()), node.getDistributionType()); } @@ -559,13 +559,13 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext cont PlanNode left = context.rewrite(node.getLeft()); PlanNode right = context.rewrite(node.getRight()); - return new SpatialJoinNode(node.getId(), node.getType(), left, right, canonicalizeAndDistinct(node.getOutputSymbols()), castToRowExpression(canonicalize(castToExpression(node.getFilter()))), canonicalize(node.getLeftPartitionSymbol()), canonicalize(node.getRightPartitionSymbol()), node.getKdbTree()); + return new SpatialJoinNode(node.getId(), node.getType(), left, right, canonicalizeAndDistinctVariable(node.getOutputVariables()), castToRowExpression(canonicalize(castToExpression(node.getFilter()))), canonicalize(node.getLeftPartitionVariable()), canonicalize(node.getRightPartitionVariable()), node.getKdbTree()); } @Override public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext context) { - return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), canonicalize(node.getLookupSymbols()), node.getOutputSymbols(), node.getAssignments(), node.getCurrentConstraint()); + return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), canonicalizeVariables(node.getLookupVariables()), node.getOutputVariables(), node.getAssignments(), node.getCurrentConstraint()); } @Override @@ -574,25 +574,25 @@ public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext context) PlanNode probeSource = context.rewrite(node.getProbeSource()); PlanNode indexSource = context.rewrite(node.getIndexSource()); - return new IndexJoinNode(node.getId(), node.getType(), probeSource, indexSource, canonicalizeIndexJoinCriteria(node.getCriteria()), canonicalize(node.getProbeHashSymbol()), canonicalize(node.getIndexHashSymbol())); + return new IndexJoinNode(node.getId(), node.getType(), probeSource, indexSource, canonicalizeIndexJoinCriteria(node.getCriteria()), canonicalize(node.getProbeHashVariable()), canonicalize(node.getIndexHashVariable())); } @Override public PlanNode visitUnion(UnionNode node, RewriteContext context) { - return new UnionNode(node.getId(), rewriteSources(node, context).build(), canonicalizeSetOperationSymbolMap(node.getSymbolMapping()), canonicalizeAndDistinct(node.getOutputSymbols())); + return new UnionNode(node.getId(), rewriteSources(node, context).build(), canonicalizeSetOperationVariableMap(node.getVariableMapping())); } @Override public PlanNode visitIntersect(IntersectNode node, RewriteContext context) { - return new IntersectNode(node.getId(), rewriteSources(node, context).build(), canonicalizeSetOperationSymbolMap(node.getSymbolMapping()), canonicalizeAndDistinct(node.getOutputSymbols())); + return new IntersectNode(node.getId(), rewriteSources(node, context).build(), canonicalizeSetOperationVariableMap(node.getVariableMapping())); } @Override public PlanNode visitExcept(ExceptNode node, RewriteContext context) { - return new ExceptNode(node.getId(), rewriteSources(node, context).build(), canonicalizeSetOperationSymbolMap(node.getSymbolMapping()), canonicalizeAndDistinct(node.getOutputSymbols())); + return new ExceptNode(node.getId(), rewriteSources(node, context).build(), canonicalizeSetOperationVariableMap(node.getVariableMapping())); } private static ImmutableList.Builder rewriteSources(SetOperationNode node, RewriteContext context) @@ -608,7 +608,7 @@ private static ImmutableList.Builder rewriteSources(SetOperationNode n public PlanNode visitTableWriter(TableWriterNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); - SymbolMapper mapper = new SymbolMapper(mapping); + SymbolMapper mapper = new SymbolMapper(mapping, types); return mapper.map(node, source); } @@ -621,66 +621,74 @@ protected PlanNode visitPlan(PlanNode node, RewriteContext context) private void map(Symbol symbol, Symbol canonical) { Preconditions.checkArgument(!symbol.equals(canonical), "Can't map symbol to itself: %s", symbol); - mapping.put(symbol, canonical); + mapping.put(symbol.getName(), canonical.getName()); + } + + private void map(VariableReferenceExpression variable, VariableReferenceExpression canonical) + { + Preconditions.checkArgument(!variable.equals(canonical), "Can't map variable to itself: %s", variable); + mapping.put(variable.getName(), canonical.getName()); } private Assignments canonicalize(Assignments oldAssignments) { - Map computedExpressions = new HashMap<>(); + Map computedExpressions = new HashMap<>(); Assignments.Builder assignments = Assignments.builder(); - for (Map.Entry entry : oldAssignments.getMap().entrySet()) { + for (Map.Entry entry : oldAssignments.getMap().entrySet()) { Expression expression = canonicalize(entry.getValue()); if (expression instanceof SymbolReference) { // Always map a trivial symbol projection - Symbol symbol = Symbol.from(expression); - if (!symbol.equals(entry.getKey())) { - map(entry.getKey(), symbol); + VariableReferenceExpression variable = new VariableReferenceExpression(Symbol.from(expression).getName(), types.get(Symbol.from(expression))); + if (!variable.getName().equals(entry.getKey().getName())) { + map(entry.getKey(), variable); } } else if (ExpressionDeterminismEvaluator.isDeterministic(expression) && !(expression instanceof NullLiteral)) { // Try to map same deterministic expressions within a projection into the same symbol // Omit NullLiterals since those have ambiguous types - Symbol computedSymbol = computedExpressions.get(expression); - if (computedSymbol == null) { + VariableReferenceExpression computedVariable = computedExpressions.get(expression); + if (computedVariable == null) { // If we haven't seen the expression before in this projection, record it computedExpressions.put(expression, entry.getKey()); } else { // If we have seen the expression before and if it is deterministic - // then we can rewrite references to the current symbol in terms of the parallel computedSymbol in the projection - map(entry.getKey(), computedSymbol); + // then we can rewrite references to the current symbol in terms of the parallel computedVariable in the projection + map(entry.getKey(), computedVariable); } } - Symbol canonical = canonicalize(entry.getKey()); + VariableReferenceExpression canonical = canonicalize(entry.getKey()); assignments.put(canonical, expression); } return assignments.build(); } - private Optional canonicalize(Optional symbol) + private Symbol canonicalize(Symbol symbol) { - if (symbol.isPresent()) { - return Optional.of(canonicalize(symbol.get())); + String canonical = symbol.getName(); + while (mapping.containsKey(canonical)) { + canonical = mapping.get(canonical); } - return Optional.empty(); + return new Symbol(canonical); } - private Symbol canonicalize(Symbol symbol) + private VariableReferenceExpression canonicalize(VariableReferenceExpression variable) { - Symbol canonical = symbol; + String canonical = variable.getName(); while (mapping.containsKey(canonical)) { canonical = mapping.get(canonical); } - return canonical; + return new VariableReferenceExpression(canonical, types.get(new Symbol(canonical))); } - private List canonicalize(List values) + private Optional canonicalize(Optional variable) { - return values.stream() - .map(this::canonicalize) - .collect(toImmutableList()); + if (variable.isPresent()) { + return Optional.of(canonicalize(variable.get())); + } + return Optional.empty(); } private Expression canonicalize(Expression value) @@ -709,27 +717,40 @@ private List canonicalizeAndDistinct(List outputs) return builder.build(); } + private List canonicalizeAndDistinctVariable(List outputs) + { + Set added = new HashSet<>(); + ImmutableList.Builder builder = ImmutableList.builder(); + for (VariableReferenceExpression variable : outputs) { + VariableReferenceExpression canonical = canonicalize(variable); + if (added.add(canonical)) { + builder.add(canonical); + } + } + return builder.build(); + } + private WindowNode.Specification canonicalizeAndDistinct(WindowNode.Specification specification) { return new WindowNode.Specification( - canonicalizeAndDistinct(specification.getPartitionBy()), + canonicalizeAndDistinctVariable(specification.getPartitionBy()), specification.getOrderingScheme().map(this::canonicalizeAndDistinct)); } private OrderingScheme canonicalizeAndDistinct(OrderingScheme orderingScheme) { - Set added = new HashSet<>(); - ImmutableList.Builder symbols = ImmutableList.builder(); - ImmutableMap.Builder orderings = ImmutableMap.builder(); - for (Symbol symbol : orderingScheme.getOrderBy()) { - Symbol canonical = canonicalize(symbol); + Set added = new HashSet<>(); + ImmutableList.Builder variables = ImmutableList.builder(); + ImmutableMap.Builder orderings = ImmutableMap.builder(); + for (VariableReferenceExpression variable : orderingScheme.getOrderBy()) { + VariableReferenceExpression canonical = canonicalize(variable); if (added.add(canonical)) { - symbols.add(canonical); - orderings.put(canonical, orderingScheme.getOrdering(symbol)); + variables.add(canonical); + orderings.put(canonical, orderingScheme.getOrdering(variable)); } } - return new OrderingScheme(symbols.build(), orderings.build()); + return new OrderingScheme(variables.build(), orderings.build()); } private Set canonicalize(Set symbols) @@ -739,6 +760,13 @@ private Set canonicalize(Set symbols) .collect(toImmutableSet()); } + private Set canonicalizeVariables(Set variables) + { + return variables.stream() + .map(this::canonicalize) + .collect(toImmutableSet()); + } + private List canonicalizeJoinCriteria(List criteria) { ImmutableList.Builder builder = ImmutableList.builder(); @@ -759,14 +787,14 @@ private List canonicalizeIndexJoinCriteria(List canonicalizeSetOperationSymbolMap(ListMultimap setOperationSymbolMap) + private ListMultimap canonicalizeSetOperationVariableMap(ListMultimap setOperationVariableMap) { - ImmutableListMultimap.Builder builder = ImmutableListMultimap.builder(); - Set addedSymbols = new HashSet<>(); - for (Map.Entry> entry : setOperationSymbolMap.asMap().entrySet()) { - Symbol canonicalOutputSymbol = canonicalize(entry.getKey()); - if (addedSymbols.add(canonicalOutputSymbol)) { - builder.putAll(canonicalOutputSymbol, Iterables.transform(entry.getValue(), this::canonicalize)); + ImmutableListMultimap.Builder builder = ImmutableListMultimap.builder(); + Set addedSymbols = new HashSet<>(); + for (Map.Entry> entry : setOperationVariableMap.asMap().entrySet()) { + VariableReferenceExpression canonicalOutputVariable = canonicalize(entry.getKey()); + if (addedSymbols.add(canonicalOutputVariable)) { + builder.putAll(canonicalOutputVariable, Iterables.transform(entry.getValue(), this::canonicalize)); } } return builder.build(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java index 8411445174ae6..b6863ddd4f643 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java @@ -24,6 +24,7 @@ import com.facebook.presto.spi.predicate.Range; import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.predicate.ValueSet; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.planner.ExpressionDomainTranslator; @@ -162,39 +163,39 @@ public PlanNode visitFilter(FilterNode node, RewriteContext context) TupleDomain tupleDomain = fromPredicate(metadata, session, castToExpression(node.getPredicate()), types).getTupleDomain(); if (source instanceof RowNumberNode) { - Symbol rowNumberSymbol = ((RowNumberNode) source).getRowNumberSymbol(); - OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberSymbol); + VariableReferenceExpression rowNumberVariable = ((RowNumberNode) source).getRowNumberVariable(); + OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberVariable); if (upperBound.isPresent()) { source = mergeLimit(((RowNumberNode) source), upperBound.getAsInt()); - return rewriteFilterSource(node, source, rowNumberSymbol, upperBound.getAsInt()); + return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt()); } } else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionManager()) && isOptimizeTopNRowNumber(session)) { WindowNode windowNode = (WindowNode) source; - Symbol rowNumberSymbol = getOnlyElement(windowNode.getWindowFunctions().entrySet()).getKey(); - OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberSymbol); + VariableReferenceExpression rowNumberVariable = getOnlyElement(windowNode.getCreatedVariable()); + OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberVariable); if (upperBound.isPresent()) { source = convertToTopNRowNumber(windowNode, upperBound.getAsInt()); - return rewriteFilterSource(node, source, rowNumberSymbol, upperBound.getAsInt()); + return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt()); } } return replaceChildren(node, ImmutableList.of(source)); } - private PlanNode rewriteFilterSource(FilterNode filterNode, PlanNode source, Symbol rowNumberSymbol, int upperBound) + private PlanNode rewriteFilterSource(FilterNode filterNode, PlanNode source, VariableReferenceExpression rowNumberVariable, int upperBound) { ExtractionResult extractionResult = fromPredicate(metadata, session, castToExpression(filterNode.getPredicate()), types); TupleDomain tupleDomain = extractionResult.getTupleDomain(); - if (!isEqualRange(tupleDomain, rowNumberSymbol, upperBound)) { + if (!isEqualRange(tupleDomain, rowNumberVariable, upperBound)) { return new FilterNode(filterNode.getId(), source, filterNode.getPredicate()); } // Remove the row number domain because it is absorbed into the node Map newDomains = tupleDomain.getDomains().get().entrySet().stream() - .filter(entry -> !entry.getKey().equals(rowNumberSymbol)) + .filter(entry -> !entry.getKey().equals(rowNumberVariable)) .collect(toMap(Map.Entry::getKey, Map.Entry::getValue)); // Construct a new predicate @@ -209,22 +210,22 @@ private PlanNode rewriteFilterSource(FilterNode filterNode, PlanNode source, Sym return new FilterNode(filterNode.getId(), source, castToRowExpression(newPredicate)); } - private static boolean isEqualRange(TupleDomain tupleDomain, Symbol symbol, long upperBound) + private static boolean isEqualRange(TupleDomain tupleDomain, VariableReferenceExpression variable, long upperBound) { if (tupleDomain.isNone()) { return false; } - Domain domain = tupleDomain.getDomains().get().get(symbol); + Domain domain = tupleDomain.getDomains().get().get(new Symbol(variable.getName())); return domain.getValues().equals(ValueSet.ofRanges(Range.lessThanOrEqual(domain.getType(), upperBound))); } - private static OptionalInt extractUpperBound(TupleDomain tupleDomain, Symbol symbol) + private static OptionalInt extractUpperBound(TupleDomain tupleDomain, VariableReferenceExpression variable) { if (tupleDomain.isNone()) { return OptionalInt.empty(); } - Domain rowNumberDomain = tupleDomain.getDomains().get().get(symbol); + Domain rowNumberDomain = tupleDomain.getDomains().get().get(new Symbol(variable.getName())); if (rowNumberDomain == null) { return OptionalInt.empty(); } @@ -256,7 +257,7 @@ private static RowNumberNode mergeLimit(RowNumberNode node, int newRowCountPerPa if (node.getMaxRowCountPerPartition().isPresent()) { newRowCountPerPartition = Math.min(node.getMaxRowCountPerPartition().get(), newRowCountPerPartition); } - return new RowNumberNode(node.getId(), node.getSource(), node.getPartitionBy(), node.getRowNumberSymbol(), Optional.of(newRowCountPerPartition), node.getHashSymbol()); + return new RowNumberNode(node.getId(), node.getSource(), node.getPartitionBy(), node.getRowNumberVariable(), Optional.of(newRowCountPerPartition), node.getHashVariable()); } private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limit) @@ -264,7 +265,7 @@ private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limi return new TopNRowNumberNode(idAllocator.getNextId(), windowNode.getSource(), windowNode.getSpecification(), - getOnlyElement(windowNode.getWindowFunctions().keySet()), + getOnlyElement(windowNode.getCreatedVariable()), limit, false, Optional.empty()); @@ -280,8 +281,8 @@ private static boolean canOptimizeWindowFunction(WindowNode node, FunctionManage if (node.getWindowFunctions().size() != 1) { return false; } - Symbol rowNumberSymbol = getOnlyElement(node.getWindowFunctions().entrySet()).getKey(); - return isRowNumberMetadata(functionManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberSymbol).getFunctionHandle())); + VariableReferenceExpression rowNumberVariable = getOnlyElement(node.getWindowFunctions().keySet()); + return isRowNumberMetadata(functionManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle())); } private static boolean isRowNumberMetadata(FunctionMetadata functionMetadata) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowNodeUtil.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowNodeUtil.java index 62d7bb40c0344..5a28ed50b0ff4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowNodeUtil.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowNodeUtil.java @@ -14,8 +14,10 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType; import com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType; @@ -42,14 +44,15 @@ public final class WindowNodeUtil { private WindowNodeUtil() {} - public static boolean dependsOn(WindowNode parent, WindowNode child) + public static boolean dependsOn(WindowNode parent, WindowNode child, TypeProvider types) { - return parent.getPartitionBy().stream().anyMatch(child.getCreatedSymbols()::contains) - || (parent.getOrderingScheme().isPresent() && parent.getOrderingScheme().get().getOrderBy().stream().anyMatch(child.getCreatedSymbols()::contains)) + return parent.getPartitionBy().stream().anyMatch(child.getCreatedVariable()::contains) + || (parent.getOrderingScheme().isPresent() && parent.getOrderingScheme().get().getOrderBy().stream() + .anyMatch(child.getCreatedVariable()::contains)) || parent.getWindowFunctions().values().stream() - .map(WindowNodeUtil::extractWindowFunctionUnique) + .map(function -> extractWindowFunctionUniqueVariables(function, types)) .flatMap(Collection::stream) - .anyMatch(child.getCreatedSymbols()::contains); + .anyMatch(child.getCreatedVariable()::contains); } public static WindowType toWindowType(WindowFrame.Type type) @@ -97,4 +100,18 @@ public static Set extractWindowFunctionUnique(WindowNode.Function functi } return builder.build(); } + + public static Set extractWindowFunctionUniqueVariables(WindowNode.Function function, TypeProvider types) + { + ImmutableSet.Builder builder = ImmutableSet.builder(); + for (RowExpression argument : function.getFunctionCall().getArguments()) { + if (isExpression(argument)) { + builder.addAll(SymbolsExtractor.extractAllVariable(castToExpression(argument), types)); + } + else { + builder.addAll(SymbolsExtractor.extractAll(argument)); + } + } + return builder.build(); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java index e29496203d4dc..901a248400342 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.optimizations.joins; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.FilterNode; @@ -49,7 +49,7 @@ */ public class JoinGraph { - private final Optional> assignments; + private final Optional> assignments; private final List filters; private final List nodes; // nodes in order of their appearance in tree plan (left, right, parent) private final Multimap edges; @@ -92,7 +92,7 @@ public JoinGraph( Multimap edges, PlanNodeId rootId, List filters, - Optional> assignments) + Optional> assignments) { this.nodes = nodes; this.edges = edges; @@ -101,12 +101,12 @@ public JoinGraph( this.assignments = assignments; } - public JoinGraph withAssignments(Map assignments) + public JoinGraph withAssignments(Map assignments) { return new JoinGraph(nodes, edges, rootId, filters, Optional.of(assignments)); } - public Optional> getAssignments() + public Optional> getAssignments() { return assignments; } @@ -204,15 +204,15 @@ private JoinGraph joinWith(JoinGraph other, List joinCl .build(); for (JoinNode.EquiJoinClause edge : joinClauses) { - Symbol leftSymbol = edge.getLeft(); - Symbol rightSymbol = edge.getRight(); - checkState(context.containsSymbol(leftSymbol)); - checkState(context.containsSymbol(rightSymbol)); - - PlanNode left = context.getSymbolSource(leftSymbol); - PlanNode right = context.getSymbolSource(rightSymbol); - edges.put(left.getId(), new Edge(right, leftSymbol, rightSymbol)); - edges.put(right.getId(), new Edge(left, rightSymbol, leftSymbol)); + VariableReferenceExpression leftVariable = edge.getLeft(); + VariableReferenceExpression rightVariable = edge.getRight(); + checkState(context.containsVariable(leftVariable)); + checkState(context.containsVariable(rightVariable)); + + PlanNode left = context.getVariableSource(leftVariable); + PlanNode right = context.getVariableSource(rightVariable); + edges.put(left.getId(), new Edge(right, leftVariable, rightVariable)); + edges.put(right.getId(), new Edge(left, rightVariable, leftVariable)); } return new JoinGraph(nodes, edges.build(), newRoot, joinedFilters, Optional.empty()); @@ -244,8 +244,8 @@ protected JoinGraph visitPlan(PlanNode node, Context context) } } - for (Symbol symbol : node.getOutputSymbols()) { - context.setSymbolSource(symbol, node); + for (VariableReferenceExpression variable : node.getOutputVariables()) { + context.setVariableSource(variable, node); } return new JoinGraph(node); } @@ -305,11 +305,11 @@ private boolean isTrivialGraph(JoinGraph graph) private JoinGraph replacementGraph(PlanNode oldNode, PlanNode newNode, Context context) { // TODO optimize when idea is generally approved - List symbols = context.symbolSources.entrySet().stream() + List variables = context.variableSources.entrySet().stream() .filter(entry -> entry.getValue() == oldNode) .map(Map.Entry::getKey) .collect(toImmutableList()); - symbols.forEach(symbol -> context.symbolSources.put(symbol, newNode)); + variables.forEach(variable -> context.variableSources.put(variable, newNode)); return new JoinGraph(newNode); } @@ -318,14 +318,14 @@ private JoinGraph replacementGraph(PlanNode oldNode, PlanNode newNode, Context c public static class Edge { private final PlanNode targetNode; - private final Symbol sourceSymbol; - private final Symbol targetSymbol; + private final VariableReferenceExpression sourceVariable; + private final VariableReferenceExpression targetVariable; - public Edge(PlanNode targetNode, Symbol sourceSymbol, Symbol targetSymbol) + public Edge(PlanNode targetNode, VariableReferenceExpression sourceVariable, VariableReferenceExpression targetVariable) { this.targetNode = requireNonNull(targetNode, "targetNode is null"); - this.sourceSymbol = requireNonNull(sourceSymbol, "sourceSymbol is null"); - this.targetSymbol = requireNonNull(targetSymbol, "targetSymbol is null"); + this.sourceVariable = requireNonNull(sourceVariable, "sourceVariable is null"); + this.targetVariable = requireNonNull(targetVariable, "targetVariable is null"); } public PlanNode getTargetNode() @@ -333,27 +333,27 @@ public PlanNode getTargetNode() return targetNode; } - public Symbol getSourceSymbol() + public VariableReferenceExpression getSourceVariable() { - return sourceSymbol; + return sourceVariable; } - public Symbol getTargetSymbol() + public VariableReferenceExpression getTargetVariable() { - return targetSymbol; + return targetVariable; } } private static class Context { - private final Map symbolSources = new HashMap<>(); + private final Map variableSources = new HashMap<>(); // TODO When com.facebook.presto.sql.planner.optimizations.EliminateCrossJoins is removed, remove 'joinGraphs' private final List joinGraphs = new ArrayList<>(); - public void setSymbolSource(Symbol symbol, PlanNode node) + public void setVariableSource(VariableReferenceExpression variable, PlanNode node) { - symbolSources.put(symbol, node); + variableSources.put(variable, node); } public void addSubGraph(JoinGraph graph) @@ -361,15 +361,15 @@ public void addSubGraph(JoinGraph graph) joinGraphs.add(graph); } - public boolean containsSymbol(Symbol symbol) + public boolean containsVariable(VariableReferenceExpression variable) { - return symbolSources.containsKey(symbol); + return variableSources.containsKey(variable); } - public PlanNode getSymbolSource(Symbol symbol) + public PlanNode getVariableSource(VariableReferenceExpression variable) { - checkState(containsSymbol(symbol)); - return symbolSources.get(symbol); + checkState(containsVariable(variable)); + return variableSources.get(variable); } public List getGraphs() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AggregationNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AggregationNode.java index 26a4a9d28721c..c04f593cedcda 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AggregationNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AggregationNode.java @@ -17,8 +17,8 @@ import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.OrderingScheme; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.tree.Expression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -44,24 +44,24 @@ public class AggregationNode extends InternalPlanNode { private final PlanNode source; - private final Map aggregations; + private final Map aggregations; private final GroupingSetDescriptor groupingSets; - private final List preGroupedSymbols; + private final List preGroupedVariables; private final Step step; - private final Optional hashSymbol; - private final Optional groupIdSymbol; - private final List outputs; + private final Optional hashVariable; + private final Optional groupIdVariable; + private final List outputs; @JsonCreator public AggregationNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("aggregations") Map aggregations, + @JsonProperty("aggregations") Map aggregations, @JsonProperty("groupingSets") GroupingSetDescriptor groupingSets, - @JsonProperty("preGroupedSymbols") List preGroupedSymbols, + @JsonProperty("preGroupedVariables") List preGroupedVariables, @JsonProperty("step") Step step, - @JsonProperty("hashSymbol") Optional hashSymbol, - @JsonProperty("groupIdSymbol") Optional groupIdSymbol) + @JsonProperty("hashVariable") Optional hashVariable, + @JsonProperty("groupIdVariable") Optional groupIdVariable) { super(id); @@ -69,10 +69,10 @@ public AggregationNode( this.aggregations = ImmutableMap.copyOf(requireNonNull(aggregations, "aggregations is null")); requireNonNull(groupingSets, "groupingSets is null"); - groupIdSymbol.ifPresent(symbol -> checkArgument(groupingSets.getGroupingKeys().contains(symbol), "Grouping columns does not contain groupId column")); + groupIdVariable.ifPresent(variable -> checkArgument(groupingSets.getGroupingKeys().contains(variable), "Grouping columns does not contain groupId column")); this.groupingSets = groupingSets; - this.groupIdSymbol = requireNonNull(groupIdSymbol); + this.groupIdVariable = requireNonNull(groupIdVariable); boolean noOrderBy = aggregations.values().stream() .map(Aggregation::getOrderBy) @@ -80,26 +80,26 @@ public AggregationNode( checkArgument(noOrderBy || step == SINGLE, "ORDER BY does not support distributed aggregation"); this.step = step; - this.hashSymbol = hashSymbol; + this.hashVariable = hashVariable; - requireNonNull(preGroupedSymbols, "preGroupedSymbols is null"); - checkArgument(preGroupedSymbols.isEmpty() || groupingSets.getGroupingKeys().containsAll(preGroupedSymbols), "Pre-grouped symbols must be a subset of the grouping keys"); - this.preGroupedSymbols = ImmutableList.copyOf(preGroupedSymbols); + requireNonNull(preGroupedVariables, "preGroupedVariables is null"); + checkArgument(preGroupedVariables.isEmpty() || groupingSets.getGroupingKeys().containsAll(preGroupedVariables), "Pre-grouped variables must be a subset of the grouping keys"); + this.preGroupedVariables = ImmutableList.copyOf(preGroupedVariables); - ImmutableList.Builder outputs = ImmutableList.builder(); + ImmutableList.Builder outputs = ImmutableList.builder(); outputs.addAll(groupingSets.getGroupingKeys()); - hashSymbol.ifPresent(outputs::add); + hashVariable.ifPresent(outputs::add); outputs.addAll(aggregations.keySet()); this.outputs = outputs.build(); } - public List getGroupingKeys() + public List getGroupingKeys() { return groupingSets.getGroupingKeys(); } - @JsonProperty("groupingSets") + @JsonProperty public GroupingSetDescriptor getGroupingSets() { return groupingSets; @@ -135,21 +135,21 @@ public List getSources() } @Override - public List getOutputSymbols() + public List getOutputVariables() { return outputs; } @JsonProperty - public Map getAggregations() + public Map getAggregations() { return aggregations; } - @JsonProperty("preGroupedSymbols") - public List getPreGroupedSymbols() + @JsonProperty + public List getPreGroupedVariables() { - return preGroupedSymbols; + return preGroupedVariables; } public int getGroupingSetCount() @@ -162,28 +162,28 @@ public Set getGlobalGroupingSets() return groupingSets.getGlobalGroupingSets(); } - @JsonProperty("source") + @JsonProperty public PlanNode getSource() { return source; } - @JsonProperty("step") + @JsonProperty public Step getStep() { return step; } - @JsonProperty("hashSymbol") - public Optional getHashSymbol() + @JsonProperty + public Optional getHashVariable() { - return hashSymbol; + return hashVariable; } - @JsonProperty("groupIdSymbol") - public Optional getGroupIdSymbol() + @JsonProperty + public Optional getGroupIdVariable() { - return groupIdSymbol; + return groupIdVariable; } public boolean hasOrderings() @@ -202,7 +202,7 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new AggregationNode(getId(), Iterables.getOnlyElement(newChildren), aggregations, groupingSets, preGroupedSymbols, step, hashSymbol, groupIdSymbol); + return new AggregationNode(getId(), Iterables.getOnlyElement(newChildren), aggregations, groupingSets, preGroupedVariables, step, hashVariable, groupIdVariable); } public boolean isDecomposable(FunctionManager functionManager) @@ -239,7 +239,7 @@ public boolean hasSingleNodeExecutionPreference(FunctionManager functionManager) public boolean isStreamable() { - return !preGroupedSymbols.isEmpty() && groupingSets.getGroupingSetCount() == 1 && groupingSets.getGlobalGroupingSets().isEmpty(); + return !preGroupedVariables.isEmpty() && groupingSets.getGroupingSetCount() == 1 && groupingSets.getGlobalGroupingSets().isEmpty(); } public static GroupingSetDescriptor globalAggregation() @@ -247,7 +247,7 @@ public static GroupingSetDescriptor globalAggregation() return singleGroupingSet(ImmutableList.of()); } - public static GroupingSetDescriptor singleGroupingSet(List groupingKeys) + public static GroupingSetDescriptor singleGroupingSet(List groupingKeys) { Set globalGroupingSets; if (groupingKeys.isEmpty()) { @@ -260,20 +260,20 @@ public static GroupingSetDescriptor singleGroupingSet(List groupingKeys) return new GroupingSetDescriptor(groupingKeys, 1, globalGroupingSets); } - public static GroupingSetDescriptor groupingSets(List groupingKeys, int groupingSetCount, Set globalGroupingSets) + public static GroupingSetDescriptor groupingSets(List groupingKeys, int groupingSetCount, Set globalGroupingSets) { return new GroupingSetDescriptor(groupingKeys, groupingSetCount, globalGroupingSets); } public static class GroupingSetDescriptor { - private final List groupingKeys; + private final List groupingKeys; private final int groupingSetCount; private final Set globalGroupingSets; @JsonCreator public GroupingSetDescriptor( - @JsonProperty("groupingKeys") List groupingKeys, + @JsonProperty("groupingKeys") List groupingKeys, @JsonProperty("groupingSetCount") int groupingSetCount, @JsonProperty("globalGroupingSets") Set globalGroupingSets) { @@ -290,7 +290,7 @@ public GroupingSetDescriptor( } @JsonProperty - public List getGroupingKeys() + public List getGroupingKeys() { return groupingKeys; } @@ -362,7 +362,7 @@ public static class Aggregation private final Optional filter; private final Optional orderingScheme; private final boolean isDistinct; - private final Optional mask; + private final Optional mask; @JsonCreator public Aggregation( @@ -371,7 +371,7 @@ public Aggregation( @JsonProperty("filter") Optional filter, @JsonProperty("orderBy") Optional orderingScheme, @JsonProperty("distinct") boolean isDistinct, - @JsonProperty("mask") Optional mask) + @JsonProperty("mask") Optional mask) { this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); this.arguments = requireNonNull(arguments, "arguments is null"); @@ -412,7 +412,7 @@ public boolean isDistinct() } @JsonProperty - public Optional getMask() + public Optional getMask() { return mask; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java index c9964499385c0..cce5d75a63442 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.tree.ExistsPredicate; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.InPredicate; @@ -38,9 +38,9 @@ public class ApplyNode private final PlanNode subquery; /** - * Correlation symbols, returned from input (outer plan) used in subquery (inner plan) + * Correlation variables, returned from input (outer plan) used in subquery (inner plan) */ - private final List correlation; + private final List correlation; /** * Expressions that use subquery symbols. @@ -74,7 +74,7 @@ public ApplyNode( @JsonProperty("input") PlanNode input, @JsonProperty("subquery") PlanNode subquery, @JsonProperty("subqueryAssignments") Assignments subqueryAssignments, - @JsonProperty("correlation") List correlation, + @JsonProperty("correlation") List correlation, @JsonProperty("originSubqueryError") String originSubqueryError) { super(id); @@ -84,7 +84,7 @@ public ApplyNode( requireNonNull(correlation, "correlation is null"); requireNonNull(originSubqueryError, "originSubqueryError is null"); - checkArgument(input.getOutputSymbols().containsAll(correlation), "Input does not contain symbols from correlation"); + checkArgument(input.getOutputVariables().containsAll(correlation), "Input does not contain symbols from correlation"); checkArgument( subqueryAssignments.getExpressions().stream().allMatch(ApplyNode::isSupportedSubqueryExpression), "Unexpected expression used for subquery expression"); @@ -103,31 +103,31 @@ private static boolean isSupportedSubqueryExpression(Expression expression) expression instanceof QuantifiedComparisonExpression; } - @JsonProperty("input") + @JsonProperty public PlanNode getInput() { return input; } - @JsonProperty("subquery") + @JsonProperty public PlanNode getSubquery() { return subquery; } - @JsonProperty("subqueryAssignments") + @JsonProperty public Assignments getSubqueryAssignments() { return subqueryAssignments; } - @JsonProperty("correlation") - public List getCorrelation() + @JsonProperty + public List getCorrelation() { return correlation; } - @JsonProperty("originSubqueryError") + @JsonProperty public String getOriginSubqueryError() { return originSubqueryError; @@ -140,11 +140,10 @@ public List getSources() } @Override - @JsonProperty("outputSymbols") - public List getOutputSymbols() + public List getOutputVariables() { - return ImmutableList.builder() - .addAll(input.getOutputSymbols()) + return ImmutableList.builder() + .addAll(input.getOutputVariables()) .addAll(subqueryAssignments.getOutputs()) .build(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignUniqueId.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignUniqueId.java index 56959ff8c9712..f901ff9ff3aba 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignUniqueId.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignUniqueId.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -29,25 +29,25 @@ public class AssignUniqueId extends InternalPlanNode { private final PlanNode source; - private final Symbol idColumn; + private final VariableReferenceExpression idVariable; @JsonCreator public AssignUniqueId( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("idColumn") Symbol unique) + @JsonProperty("idVariable") VariableReferenceExpression idVariable) { super(id); this.source = requireNonNull(source, "source is null"); - this.idColumn = requireNonNull(unique, "idColumn is null"); + this.idVariable = requireNonNull(idVariable, "idVariable is null"); } @Override - public List getOutputSymbols() + public List getOutputVariables() { - return ImmutableList.builder() - .addAll(source.getOutputSymbols()) - .add(idColumn) + return ImmutableList.builder() + .addAll(source.getOutputVariables()) + .add(idVariable) .build(); } @@ -64,9 +64,9 @@ public List getSources() } @JsonProperty - public Symbol getIdColumn() + public VariableReferenceExpression getIdVariable() { - return idColumn; + return idVariable; } @Override @@ -79,6 +79,6 @@ public R accept(InternalPlanVisitor visitor, C context) public PlanNode replaceChildren(List newChildren) { checkArgument(newChildren.size() == 1, "expected newChildren to contain 1 node"); - return new AssignUniqueId(getId(), Iterables.getOnlyElement(newChildren), idColumn); + return new AssignUniqueId(getId(), Iterables.getOnlyElement(newChildren), idVariable); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java index 749613cd80d8c..8aacfd8275637 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.plan; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; @@ -36,6 +37,8 @@ import java.util.stream.Collector; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Arrays.asList; import static java.util.Objects.requireNonNull; @@ -46,17 +49,17 @@ public static Builder builder() return new Builder(); } - public static Assignments identity(Symbol... symbols) + public static Assignments identity(VariableReferenceExpression... variables) { - return identity(asList(symbols)); + return identity(asList(variables)); } - public static Assignments identity(Iterable symbols) + public static Assignments identity(Iterable variables) { - return builder().putIdentities(symbols).build(); + return builder().putIdentities(variables).build(); } - public static Assignments copyOf(Map assignments) + public static Assignments copyOf(Map assignments) { return builder() .putAll(assignments) @@ -68,31 +71,36 @@ public static Assignments of() return builder().build(); } - public static Assignments of(Symbol symbol, Expression expression) + public static Assignments of(VariableReferenceExpression variable, Expression expression) { - return builder().put(symbol, expression).build(); + return builder().put(variable, expression).build(); } - public static Assignments of(Symbol symbol1, Expression expression1, Symbol symbol2, Expression expression2) + public static Assignments of(VariableReferenceExpression variable1, Expression expression1, VariableReferenceExpression variable2, Expression expression2) { - return builder().put(symbol1, expression1).put(symbol2, expression2).build(); + return builder().put(variable1, expression1).put(variable2, expression2).build(); } - private final Map assignments; + private final Map assignments; @JsonCreator - public Assignments(@JsonProperty("assignments") Map assignments) + public Assignments(@JsonProperty("assignments") Map assignments) { this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "assignments is null")); } - public List getOutputs() + public List getOutputSymbols() + { + return assignments.keySet().stream().map(VariableReferenceExpression::getName).map(Symbol::new).collect(toImmutableList()); + } + + public List getOutputs() { return ImmutableList.copyOf(assignments.keySet()); } @JsonProperty("assignments") - public Map getMap() + public Map getMap() { return assignments; } @@ -109,26 +117,26 @@ public Assignments rewrite(Function rewrite) .collect(toAssignments()); } - public Assignments filter(Collection symbols) + public Assignments filter(Collection variables) { - return filter(symbols::contains); + return filter(variables::contains); } - public Assignments filter(Predicate predicate) + public Assignments filter(Predicate predicate) { return assignments.entrySet().stream() .filter(entry -> predicate.apply(entry.getKey())) .collect(toAssignments()); } - public boolean isIdentity(Symbol output) + public boolean isIdentity(VariableReferenceExpression output) { Expression expression = assignments.get(output); return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(output.getName()); } - private Collector, Builder, Assignments> toAssignments() + private Collector, Builder, Assignments> toAssignments() { return Collector.of( Assignments::builder, @@ -146,18 +154,35 @@ public Collection getExpressions() } public Set getSymbols() + { + return assignments.keySet().stream().map(VariableReferenceExpression::getName).map(Symbol::new).collect(toImmutableSet()); + } + + public Set getVariables() { return assignments.keySet(); } - public Set> entrySet() + public Set> entrySet() { return assignments.entrySet(); } + public Expression get(VariableReferenceExpression variable) + { + return assignments.get(variable); + } + public Expression get(Symbol symbol) { - return assignments.get(symbol); + List candidate = assignments.entrySet().stream() + .filter(entry -> entry.getKey().getName().equals(symbol.getName())) + .map(Entry::getValue) + .collect(toImmutableList()); + if (candidate.isEmpty()) { + return null; + } + return candidate.get(0); } public int size() @@ -170,7 +195,7 @@ public boolean isEmpty() return size() == 0; } - public void forEach(BiConsumer consumer) + public void forEach(BiConsumer consumer) { assignments.forEach(consumer); } @@ -198,53 +223,53 @@ public int hashCode() public static class Builder { - private final Map assignments = new LinkedHashMap<>(); + private final Map assignments = new LinkedHashMap<>(); public Builder putAll(Assignments assignments) { return putAll(assignments.getMap()); } - public Builder putAll(Map assignments) + public Builder putAll(Map assignments) { - for (Entry assignment : assignments.entrySet()) { + for (Entry assignment : assignments.entrySet()) { put(assignment.getKey(), assignment.getValue()); } return this; } - public Builder put(Symbol symbol, Expression expression) + public Builder put(VariableReferenceExpression variable, Expression expression) { - if (assignments.containsKey(symbol)) { - Expression assignment = assignments.get(symbol); + if (assignments.containsKey(variable)) { + Expression assignment = assignments.get(variable); checkState( assignment.equals(expression), - "Symbol %s already has assignment %s, while adding %s", - symbol, + "Variable %s already has assignment %s, while adding %s", + variable, assignment, expression); } - assignments.put(symbol, expression); + assignments.put(variable, expression); return this; } - public Builder put(Entry assignment) + public Builder put(Entry assignment) { put(assignment.getKey(), assignment.getValue()); return this; } - public Builder putIdentities(Iterable symbols) + public Builder putIdentities(Iterable variables) { - for (Symbol symbol : symbols) { - putIdentity(symbol); + for (VariableReferenceExpression variable : variables) { + putIdentity(variable); } return this; } - public Builder putIdentity(Symbol symbol) + public Builder putIdentity(VariableReferenceExpression variable) { - put(symbol, symbol.toSymbolReference()); + put(variable, new SymbolReference(variable.getName())); return this; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DeleteNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DeleteNode.java index 0e579e6cec97c..f4ed7dcc588ba 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DeleteNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DeleteNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.TableWriterNode.DeleteHandle; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -33,23 +33,23 @@ public class DeleteNode { private final PlanNode source; private final DeleteHandle target; - private final Symbol rowId; - private final List outputs; + private final VariableReferenceExpression rowId; + private final List outputVariables; @JsonCreator public DeleteNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, @JsonProperty("target") DeleteHandle target, - @JsonProperty("rowId") Symbol rowId, - @JsonProperty("outputs") List outputs) + @JsonProperty("rowId") VariableReferenceExpression rowId, + @JsonProperty("outputVariables") List outputVariables) { super(id); this.source = requireNonNull(source, "source is null"); this.target = requireNonNull(target, "target is null"); this.rowId = requireNonNull(rowId, "rowId is null"); - this.outputs = ImmutableList.copyOf(requireNonNull(outputs, "outputs is null")); + this.outputVariables = ImmutableList.copyOf(requireNonNull(outputVariables, "outputVariables is null")); } @JsonProperty @@ -65,16 +65,16 @@ public DeleteHandle getTarget() } @JsonProperty - public Symbol getRowId() + public VariableReferenceExpression getRowId() { return rowId; } - @JsonProperty("outputs") + @JsonProperty @Override - public List getOutputSymbols() + public List getOutputVariables() { - return outputs; + return outputVariables; } @Override @@ -92,6 +92,6 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new DeleteNode(getId(), Iterables.getOnlyElement(newChildren), target, rowId, outputs); + return new DeleteNode(getId(), Iterables.getOnlyElement(newChildren), target, rowId, outputVariables); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DistinctLimitNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DistinctLimitNode.java index f18aa3aaf91d5..51b708863dbb6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DistinctLimitNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DistinctLimitNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -35,8 +35,8 @@ public class DistinctLimitNode private final PlanNode source; private final long limit; private final boolean partial; - private final List distinctSymbols; - private final Optional hashSymbol; + private final List distinctVariables; + private final Optional hashVariable; @JsonCreator public DistinctLimitNode( @@ -44,17 +44,17 @@ public DistinctLimitNode( @JsonProperty("source") PlanNode source, @JsonProperty("limit") long limit, @JsonProperty("partial") boolean partial, - @JsonProperty("distinctSymbols") List distinctSymbols, - @JsonProperty("hashSymbol") Optional hashSymbol) + @JsonProperty("distinctVariables") List distinctVariables, + @JsonProperty("hashVariable") Optional hashVariable) { super(id); this.source = requireNonNull(source, "source is null"); checkArgument(limit >= 0, "limit must be greater than or equal to zero"); this.limit = limit; this.partial = partial; - this.distinctSymbols = ImmutableList.copyOf(distinctSymbols); - this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); - checkArgument(!hashSymbol.isPresent() || !distinctSymbols.contains(hashSymbol.get()), "distinctSymbols should not contain hash symbol"); + this.distinctVariables = ImmutableList.copyOf(distinctVariables); + this.hashVariable = requireNonNull(hashVariable, "hashVariable is null"); + checkArgument(!hashVariable.isPresent() || !distinctVariables.contains(hashVariable.get()), "distinctVariables should not contain hash variable"); } @Override @@ -82,26 +82,25 @@ public boolean isPartial() } @JsonProperty - public Optional getHashSymbol() + public Optional getHashVariable() { - return hashSymbol; + return hashVariable; } @JsonProperty - public List getDistinctSymbols() + public List getDistinctVariables() { - return distinctSymbols; + return distinctVariables; } @Override - public List getOutputSymbols() + public List getOutputVariables() { - ImmutableList.Builder outputSymbols = ImmutableList.builder(); - outputSymbols.addAll(distinctSymbols); - hashSymbol.ifPresent(outputSymbols::add); - return outputSymbols.build(); + ImmutableList.Builder outputVariables = ImmutableList.builder(); + outputVariables.addAll(distinctVariables); + hashVariable.ifPresent(outputVariables::add); + return outputVariables.build(); } - @Override public R accept(InternalPlanVisitor visitor, C context) { @@ -111,6 +110,6 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new DistinctLimitNode(getId(), Iterables.getOnlyElement(newChildren), limit, partial, distinctSymbols, hashSymbol); + return new DistinctLimitNode(getId(), Iterables.getOnlyElement(newChildren), limit, partial, distinctVariables, hashVariable); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/EnforceSingleRowNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/EnforceSingleRowNode.java index 02e386f975183..6847d3542bce7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/EnforceSingleRowNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/EnforceSingleRowNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -49,9 +49,9 @@ public List getSources() } @Override - public List getOutputSymbols() + public List getOutputVariables() { - return source.getOutputSymbols(); + return source.getOutputVariables(); } @JsonProperty("source") diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExceptNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExceptNode.java index 3cf21a6426e24..38ea7bb476de7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExceptNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExceptNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ListMultimap; @@ -29,10 +29,9 @@ public class ExceptNode public ExceptNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("sources") List sources, - @JsonProperty("outputToInputs") ListMultimap outputToInputs, - @JsonProperty("outputs") List outputs) + @JsonProperty("outputToInputs") ListMultimap outputToInputs) { - super(id, sources, outputToInputs, outputs); + super(id, sources, outputToInputs); } @Override @@ -44,6 +43,6 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new ExceptNode(getId(), newChildren, getSymbolMapping(), getOutputSymbols()); + return new ExceptNode(getId(), newChildren, getVariableMapping()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java index a14a538a24d5f..541560b096659 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java @@ -14,16 +14,15 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.OrderingScheme; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; import com.facebook.presto.sql.planner.PartitioningHandle; import com.facebook.presto.sql.planner.PartitioningScheme; -import com.facebook.presto.sql.planner.Symbol; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import javax.annotation.concurrent.Immutable; @@ -87,7 +86,7 @@ public boolean isLocal() private final PartitioningScheme partitioningScheme; // for each source, the list of inputs corresponding to each output - private final List> inputs; + private final List> inputs; private final Optional orderingScheme; @@ -98,7 +97,7 @@ public ExchangeNode( @JsonProperty("scope") Scope scope, @JsonProperty("partitioningScheme") PartitioningScheme partitioningScheme, @JsonProperty("sources") List sources, - @JsonProperty("inputs") List> inputs, + @JsonProperty("inputs") List> inputs, @JsonProperty("orderingScheme") Optional orderingScheme) { super(id); @@ -111,10 +110,10 @@ public ExchangeNode( requireNonNull(orderingScheme, "orderingScheme is null"); checkArgument(!inputs.isEmpty(), "inputs is empty"); - checkArgument(inputs.stream().allMatch(inputSymbols -> inputSymbols.size() == partitioningScheme.getOutputLayout().size()), "Input symbols do not match output symbols"); + checkArgument(inputs.stream().allMatch(inputVariables -> inputVariables.size() == partitioningScheme.getOutputLayout().size()), "Input symbols do not match output symbols"); checkArgument(inputs.size() == sources.size(), "Must have same number of input lists as sources"); for (int i = 0; i < inputs.size(); i++) { - checkArgument(ImmutableSet.copyOf(sources.get(i).getOutputSymbols()).containsAll(inputs.get(i)), "Source does not supply all required input symbols"); + checkArgument(sources.get(i).getOutputVariables().containsAll(inputs.get(i)), "Source does not supply all required input variables"); } checkArgument(!scope.isLocal() || partitioningScheme.getPartitioning().getArguments().stream().allMatch(ArgumentBinding::isVariable), @@ -137,12 +136,12 @@ public ExchangeNode( this.orderingScheme = orderingScheme; } - public static ExchangeNode systemPartitionedExchange(PlanNodeId id, Scope scope, PlanNode child, List partitioningColumns, Optional hashColumn) + public static ExchangeNode systemPartitionedExchange(PlanNodeId id, Scope scope, PlanNode child, List partitioningColumns, Optional hashColumn) { return systemPartitionedExchange(id, scope, child, partitioningColumns, hashColumn, false); } - public static ExchangeNode systemPartitionedExchange(PlanNodeId id, Scope scope, PlanNode child, List partitioningColumns, Optional hashColumn, boolean replicateNullsAndAny) + public static ExchangeNode systemPartitionedExchange(PlanNodeId id, Scope scope, PlanNode child, List partitioningColumns, Optional hashColumn, boolean replicateNullsAndAny) { return partitionedExchange( id, @@ -153,12 +152,12 @@ public static ExchangeNode systemPartitionedExchange(PlanNodeId id, Scope scope, replicateNullsAndAny); } - public static ExchangeNode partitionedExchange(PlanNodeId id, Scope scope, PlanNode child, Partitioning partitioning, Optional hashColumn) + public static ExchangeNode partitionedExchange(PlanNodeId id, Scope scope, PlanNode child, Partitioning partitioning, Optional hashColumn) { return partitionedExchange(id, scope, child, partitioning, hashColumn, false); } - public static ExchangeNode partitionedExchange(PlanNodeId id, Scope scope, PlanNode child, Partitioning partitioning, Optional hashColumn, boolean replicateNullsAndAny) + public static ExchangeNode partitionedExchange(PlanNodeId id, Scope scope, PlanNode child, Partitioning partitioning, Optional hashColumn, boolean replicateNullsAndAny) { return partitionedExchange( id, @@ -166,7 +165,7 @@ public static ExchangeNode partitionedExchange(PlanNodeId id, Scope scope, PlanN child, new PartitioningScheme( partitioning, - child.getOutputSymbols(), + child.getOutputVariables(), hashColumn, replicateNullsAndAny, Optional.empty())); @@ -193,9 +192,9 @@ public static ExchangeNode replicatedExchange(PlanNodeId id, Scope scope, PlanNo id, REPLICATE, scope, - new PartitioningScheme(Partitioning.create(FIXED_BROADCAST_DISTRIBUTION, ImmutableList.of()), child.getOutputSymbols()), + new PartitioningScheme(Partitioning.create(FIXED_BROADCAST_DISTRIBUTION, ImmutableList.of()), child.getOutputVariables()), ImmutableList.of(child), - ImmutableList.of(child.getOutputSymbols()), + ImmutableList.of(child.getOutputVariables()), Optional.empty()); } @@ -205,9 +204,9 @@ public static ExchangeNode gatheringExchange(PlanNodeId id, Scope scope, PlanNod id, GATHER, scope, - new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), child.getOutputSymbols()), + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), child.getOutputVariables()), ImmutableList.of(child), - ImmutableList.of(child.getOutputSymbols()), + ImmutableList.of(child.getOutputVariables()), Optional.empty()); } @@ -217,7 +216,7 @@ public static ExchangeNode roundRobinExchange(PlanNodeId id, Scope scope, PlanNo id, scope, child, - new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), child.getOutputSymbols())); + new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), child.getOutputVariables())); } public static ExchangeNode mergingExchange(PlanNodeId id, Scope scope, PlanNode child, OrderingScheme orderingScheme) @@ -227,9 +226,9 @@ public static ExchangeNode mergingExchange(PlanNodeId id, Scope scope, PlanNode id, GATHER, scope, - new PartitioningScheme(Partitioning.create(partitioningHandle, ImmutableList.of()), child.getOutputSymbols()), + new PartitioningScheme(Partitioning.create(partitioningHandle, ImmutableList.of()), child.getOutputVariables()), ImmutableList.of(child), - ImmutableList.of(child.getOutputSymbols()), + ImmutableList.of(child.getOutputVariables()), Optional.of(orderingScheme)); } @@ -253,7 +252,7 @@ public List getSources() } @Override - public List getOutputSymbols() + public List getOutputVariables() { return partitioningScheme.getOutputLayout(); } @@ -271,7 +270,7 @@ public Optional getOrderingScheme() } @JsonProperty - public List> getInputs() + public List> getInputs() { return inputs; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java index b5389d33e305b..21d74727e0dca 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -31,26 +31,26 @@ public class ExplainAnalyzeNode extends InternalPlanNode { private final PlanNode source; - private final Symbol outputSymbol; + private final VariableReferenceExpression outputVariable; private final boolean verbose; @JsonCreator public ExplainAnalyzeNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("outputSymbol") Symbol outputSymbol, + @JsonProperty("outputVariable")VariableReferenceExpression outputVariable, @JsonProperty("verbose") boolean verbose) { super(id); this.source = requireNonNull(source, "source is null"); - this.outputSymbol = requireNonNull(outputSymbol, "outputSymbol is null"); + this.outputVariable = requireNonNull(outputVariable, "outputVariable is null"); this.verbose = verbose; } - @JsonProperty("outputSymbol") - public Symbol getOutputSymbol() + @JsonProperty("outputVariable") + public VariableReferenceExpression getOutputVariable() { - return outputSymbol; + return outputVariable; } @JsonProperty("source") @@ -66,9 +66,9 @@ public boolean isVerbose() } @Override - public List getOutputSymbols() + public List getOutputVariables() { - return ImmutableList.of(outputSymbol); + return ImmutableList.of(outputVariable); } @Override @@ -86,6 +86,6 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new ExplainAnalyzeNode(getId(), Iterables.getOnlyElement(newChildren), outputSymbol, isVerbose()); + return new ExplainAnalyzeNode(getId(), Iterables.getOnlyElement(newChildren), outputVariable, isVerbose()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/FilterNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/FilterNode.java index 3eff8498e0941..44708533bf5c1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/FilterNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/FilterNode.java @@ -15,7 +15,7 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.relation.RowExpression; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -48,7 +48,7 @@ public FilterNode(@JsonProperty("id") PlanNodeId id, * Get the predicate (a RowExpression of boolean type) of the FilterNode. * It serves as the criteria to determine whether the incoming rows should be filtered out or not. */ - @JsonProperty("predicate") + @JsonProperty public RowExpression getPredicate() { return predicate; @@ -63,10 +63,9 @@ public PlanNode getSource() return source; } - @Override - public List getOutputSymbols() + public List getOutputVariables() { - return source.getOutputSymbols(); + return source.getOutputVariables(); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java index bef1d81d66494..e879ba275c266 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -43,43 +43,43 @@ public class GroupIdNode { private final PlanNode source; - // in terms of output symbols - private final List> groupingSets; + // in terms of output variables + private final List> groupingSets; // tracks how each grouping set column is derived from an input column - private final Map groupingColumns; - private final List aggregationArguments; + private final Map groupingColumns; + private final List aggregationArguments; - private final Symbol groupIdSymbol; + private final VariableReferenceExpression groupIdVariable; @JsonCreator public GroupIdNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("groupingSets") List> groupingSets, - @JsonProperty("groupingColumns") Map groupingColumns, - @JsonProperty("aggregationArguments") List aggregationArguments, - @JsonProperty("groupIdSymbol") Symbol groupIdSymbol) + @JsonProperty("groupingSets") List> groupingSets, + @JsonProperty("groupingColumns") Map groupingColumns, + @JsonProperty("aggregationArguments") List aggregationArguments, + @JsonProperty("groupIdVariable") VariableReferenceExpression groupIdVariable) { super(id); this.source = requireNonNull(source); this.groupingSets = listOfListsCopy(requireNonNull(groupingSets, "groupingSets is null")); this.groupingColumns = ImmutableMap.copyOf(requireNonNull(groupingColumns)); this.aggregationArguments = ImmutableList.copyOf(aggregationArguments); - this.groupIdSymbol = requireNonNull(groupIdSymbol); + this.groupIdVariable = requireNonNull(groupIdVariable); checkArgument(Sets.intersection(groupingColumns.keySet(), ImmutableSet.copyOf(aggregationArguments)).isEmpty(), "aggregation columns and grouping set columns must be a disjoint set"); } @Override - public List getOutputSymbols() + public List getOutputVariables() { - return ImmutableList.builder() + return ImmutableList.builder() .addAll(groupingSets.stream() .flatMap(Collection::stream) .collect(toSet())) .addAll(aggregationArguments) - .add(groupIdSymbol) + .add(groupIdVariable) .build(); } @@ -96,27 +96,27 @@ public PlanNode getSource() } @JsonProperty - public List> getGroupingSets() + public List> getGroupingSets() { return groupingSets; } @JsonProperty - public Map getGroupingColumns() + public Map getGroupingColumns() { return groupingColumns; } @JsonProperty - public List getAggregationArguments() + public List getAggregationArguments() { return aggregationArguments; } @JsonProperty - public Symbol getGroupIdSymbol() + public VariableReferenceExpression getGroupIdVariable() { - return groupIdSymbol; + return groupIdVariable; } @Override @@ -125,9 +125,9 @@ public R accept(InternalPlanVisitor visitor, C context) return visitor.visitGroupId(this, context); } - public Set getInputSymbols() + public Set getInputVariables() { - return ImmutableSet.builder() + return ImmutableSet.builder() .addAll(aggregationArguments) .addAll(groupingSets.stream() .map(set -> set.stream() @@ -138,9 +138,9 @@ public Set getInputSymbols() } // returns the common grouping columns in terms of output symbols - public Set getCommonGroupingColumns() + public Set getCommonGroupingColumns() { - Set intersection = new HashSet<>(groupingSets.get(0)); + Set intersection = new HashSet<>(groupingSets.get(0)); for (int i = 1; i < groupingSets.size(); i++) { intersection.retainAll(groupingSets.get(i)); } @@ -150,6 +150,6 @@ public Set getCommonGroupingColumns() @Override public PlanNode replaceChildren(List newChildren) { - return new GroupIdNode(getId(), Iterables.getOnlyElement(newChildren), groupingSets, groupingColumns, aggregationArguments, groupIdSymbol); + return new GroupIdNode(getId(), Iterables.getOnlyElement(newChildren), groupingSets, groupingColumns, aggregationArguments, groupIdVariable); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexJoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexJoinNode.java index 69c7c96c3bad1..d8ae7d00cf12e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexJoinNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexJoinNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -35,8 +35,8 @@ public class IndexJoinNode private final PlanNode probeSource; private final PlanNode indexSource; private final List criteria; - private final Optional probeHashSymbol; - private final Optional indexHashSymbol; + private final Optional probeHashVariable; + private final Optional indexHashVariable; @JsonCreator public IndexJoinNode( @@ -45,16 +45,16 @@ public IndexJoinNode( @JsonProperty("probeSource") PlanNode probeSource, @JsonProperty("indexSource") PlanNode indexSource, @JsonProperty("criteria") List criteria, - @JsonProperty("probeHashSymbol") Optional probeHashSymbol, - @JsonProperty("indexHashSymbol") Optional indexHashSymbol) + @JsonProperty("probeHashVariable") Optional probeHashVariable, + @JsonProperty("indexHashVariable") Optional indexHashVariable) { super(id); this.type = requireNonNull(type, "type is null"); this.probeSource = requireNonNull(probeSource, "probeSource is null"); this.indexSource = requireNonNull(indexSource, "indexSource is null"); this.criteria = ImmutableList.copyOf(requireNonNull(criteria, "criteria is null")); - this.probeHashSymbol = requireNonNull(probeHashSymbol, "probeHashSymbol is null"); - this.indexHashSymbol = requireNonNull(indexHashSymbol, "indexHashSymbol is null"); + this.probeHashVariable = requireNonNull(probeHashVariable, "probeHashVariable is null"); + this.indexHashVariable = requireNonNull(indexHashVariable, "indexHashVariable is null"); } public enum Type @@ -75,40 +75,40 @@ public String getJoinLabel() } } - @JsonProperty("type") + @JsonProperty public Type getType() { return type; } - @JsonProperty("probeSource") + @JsonProperty public PlanNode getProbeSource() { return probeSource; } - @JsonProperty("indexSource") + @JsonProperty public PlanNode getIndexSource() { return indexSource; } - @JsonProperty("criteria") + @JsonProperty public List getCriteria() { return criteria; } - @JsonProperty("probeHashSymbol") - public Optional getProbeHashSymbol() + @JsonProperty + public Optional getProbeHashVariable() { - return probeHashSymbol; + return probeHashVariable; } - @JsonProperty("indexHashSymbol") - public Optional getIndexHashSymbol() + @JsonProperty + public Optional getIndexHashVariable() { - return indexHashSymbol; + return indexHashVariable; } @Override @@ -118,11 +118,11 @@ public List getSources() } @Override - public List getOutputSymbols() + public List getOutputVariables() { - return ImmutableList.builder() - .addAll(probeSource.getOutputSymbols()) - .addAll(indexSource.getOutputSymbols()) + return ImmutableList.builder() + .addAll(probeSource.getOutputVariables()) + .addAll(indexSource.getOutputVariables()) .build(); } @@ -136,29 +136,29 @@ public R accept(InternalPlanVisitor visitor, C context) public PlanNode replaceChildren(List newChildren) { checkArgument(newChildren.size() == 2, "expected newChildren to contain 2 nodes"); - return new IndexJoinNode(getId(), type, newChildren.get(0), newChildren.get(1), criteria, probeHashSymbol, indexHashSymbol); + return new IndexJoinNode(getId(), type, newChildren.get(0), newChildren.get(1), criteria, probeHashVariable, indexHashVariable); } public static class EquiJoinClause { - private final Symbol probe; - private final Symbol index; + private final VariableReferenceExpression probe; + private final VariableReferenceExpression index; @JsonCreator - public EquiJoinClause(@JsonProperty("probe") Symbol probe, @JsonProperty("index") Symbol index) + public EquiJoinClause(@JsonProperty("probe") VariableReferenceExpression probe, @JsonProperty("index") VariableReferenceExpression index) { this.probe = requireNonNull(probe, "probe is null"); this.index = requireNonNull(index, "index is null"); } @JsonProperty("probe") - public Symbol getProbe() + public VariableReferenceExpression getProbe() { return probe; } @JsonProperty("index") - public Symbol getIndex() + public VariableReferenceExpression getIndex() { return index; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexSourceNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexSourceNode.java index f8e17450c245b..c040723c145f5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexSourceNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexSourceNode.java @@ -18,7 +18,7 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.predicate.TupleDomain; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -37,9 +37,9 @@ public class IndexSourceNode { private final IndexHandle indexHandle; private final TableHandle tableHandle; - private final Set lookupSymbols; - private final List outputSymbols; - private final Map assignments; // symbol -> column + private final Set lookupVariables; + private final List outputVariables; + private final Map assignments; // symbol -> column private final TupleDomain currentConstraint; // constraint over the input data the operator will guarantee @JsonCreator @@ -47,22 +47,22 @@ public IndexSourceNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("indexHandle") IndexHandle indexHandle, @JsonProperty("tableHandle") TableHandle tableHandle, - @JsonProperty("lookupSymbols") Set lookupSymbols, - @JsonProperty("outputSymbols") List outputSymbols, - @JsonProperty("assignments") Map assignments, + @JsonProperty("lookupVariables") Set lookupVariables, + @JsonProperty("outputVariables") List outputVariables, + @JsonProperty("assignments") Map assignments, @JsonProperty("currentConstraint") TupleDomain currentConstraint) { super(id); this.indexHandle = requireNonNull(indexHandle, "indexHandle is null"); this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); - this.lookupSymbols = ImmutableSet.copyOf(requireNonNull(lookupSymbols, "lookupSymbols is null")); - this.outputSymbols = ImmutableList.copyOf(requireNonNull(outputSymbols, "outputSymbols is null")); + this.lookupVariables = ImmutableSet.copyOf(requireNonNull(lookupVariables, "lookupVariables is null")); + this.outputVariables = ImmutableList.copyOf(requireNonNull(outputVariables, "outputVariables is null")); this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "assignments is null")); this.currentConstraint = requireNonNull(currentConstraint, "effectiveTupleDomain is null"); - checkArgument(!lookupSymbols.isEmpty(), "lookupSymbols is empty"); - checkArgument(!outputSymbols.isEmpty(), "outputSymbols is empty"); - checkArgument(assignments.keySet().containsAll(lookupSymbols), "Assignments do not include all lookup symbols"); - checkArgument(outputSymbols.containsAll(lookupSymbols), "Lookup symbols need to be part of the output symbols"); + checkArgument(!lookupVariables.isEmpty(), "lookupVariables is empty"); + checkArgument(!outputVariables.isEmpty(), "outputVariables is empty"); + checkArgument(assignments.keySet().containsAll(lookupVariables), "Assignments do not include all lookup variables"); + checkArgument(outputVariables.containsAll(lookupVariables), "Lookup variables need to be part of the output variables"); } @JsonProperty @@ -78,20 +78,20 @@ public TableHandle getTableHandle() } @JsonProperty - public Set getLookupSymbols() + public Set getLookupVariables() { - return lookupSymbols; + return lookupVariables; } @Override @JsonProperty - public List getOutputSymbols() + public List getOutputVariables() { - return outputSymbols; + return outputVariables; } @JsonProperty - public Map getAssignments() + public Map getAssignments() { return assignments; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IntersectNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IntersectNode.java index da92c2632e75a..5a55f4a03fbec 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IntersectNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IntersectNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ListMultimap; @@ -31,10 +31,9 @@ public class IntersectNode public IntersectNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("sources") List sources, - @JsonProperty("outputToInputs") ListMultimap outputToInputs, - @JsonProperty("outputs") List outputs) + @JsonProperty("outputToInputs") ListMultimap outputToInputs) { - super(id, sources, outputToInputs, outputs); + super(id, sources, outputToInputs); } @Override @@ -46,6 +45,6 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new IntersectNode(getId(), newChildren, getSymbolMapping(), getOutputSymbols()); + return new IntersectNode(getId(), newChildren, getVariableMapping()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java index 1494dbd11de19..500e2330870d5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java @@ -16,8 +16,8 @@ import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.SortExpressionContext; -import com.facebook.presto.sql.planner.Symbol; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -52,10 +52,10 @@ public class JoinNode private final PlanNode left; private final PlanNode right; private final List criteria; - private final List outputSymbols; + private final List outputVariables; private final Optional filter; - private final Optional leftHashSymbol; - private final Optional rightHashSymbol; + private final Optional leftHashVariable; + private final Optional rightHashVariable; private final Optional distributionType; @JsonCreator @@ -64,10 +64,10 @@ public JoinNode(@JsonProperty("id") PlanNodeId id, @JsonProperty("left") PlanNode left, @JsonProperty("right") PlanNode right, @JsonProperty("criteria") List criteria, - @JsonProperty("outputSymbols") List outputSymbols, + @JsonProperty("outputVariables") List outputVariables, @JsonProperty("filter") Optional filter, - @JsonProperty("leftHashSymbol") Optional leftHashSymbol, - @JsonProperty("rightHashSymbol") Optional rightHashSymbol, + @JsonProperty("leftHashVariable") Optional leftHashVariable, + @JsonProperty("rightHashVariable") Optional rightHashVariable, @JsonProperty("distributionType") Optional distributionType) { super(id); @@ -75,31 +75,31 @@ public JoinNode(@JsonProperty("id") PlanNodeId id, requireNonNull(left, "left is null"); requireNonNull(right, "right is null"); requireNonNull(criteria, "criteria is null"); - requireNonNull(outputSymbols, "outputSymbols is null"); + requireNonNull(outputVariables, "outputVariables is null"); requireNonNull(filter, "filter is null"); - requireNonNull(leftHashSymbol, "leftHashSymbol is null"); - requireNonNull(rightHashSymbol, "rightHashSymbol is null"); + requireNonNull(leftHashVariable, "leftHashVariable is null"); + requireNonNull(rightHashVariable, "rightHashVariable is null"); requireNonNull(distributionType, "distributionType is null"); this.type = type; this.left = left; this.right = right; this.criteria = ImmutableList.copyOf(criteria); - this.outputSymbols = ImmutableList.copyOf(outputSymbols); + this.outputVariables = ImmutableList.copyOf(outputVariables); this.filter = filter; - this.leftHashSymbol = leftHashSymbol; - this.rightHashSymbol = rightHashSymbol; + this.leftHashVariable = leftHashVariable; + this.rightHashVariable = rightHashVariable; this.distributionType = distributionType; - Set inputSymbols = ImmutableSet.builder() - .addAll(left.getOutputSymbols()) - .addAll(right.getOutputSymbols()) + Set inputVariables = ImmutableSet.builder() + .addAll(left.getOutputVariables()) + .addAll(right.getOutputVariables()) .build(); - checkArgument(new HashSet<>(inputSymbols).containsAll(outputSymbols), "Left and right join inputs do not contain all output symbols"); - checkArgument(!isCrossJoin() || inputSymbols.size() == outputSymbols.size(), "Cross join does not support output symbols pruning or reordering"); + checkArgument(new HashSet<>(inputVariables).containsAll(outputVariables), "Left and right join inputs do not contain all output variables"); + checkArgument(!isCrossJoin() || inputVariables.size() == outputVariables.size(), "Cross join does not support output variables pruning or reordering"); - checkArgument(!(criteria.isEmpty() && leftHashSymbol.isPresent()), "Left hash symbol is only valid in an equijoin"); - checkArgument(!(criteria.isEmpty() && rightHashSymbol.isPresent()), "Right hash symbol is only valid in an equijoin"); + checkArgument(!(criteria.isEmpty() && leftHashVariable.isPresent()), "Left hash variable is only valid in an equijoin"); + checkArgument(!(criteria.isEmpty() && rightHashVariable.isPresent()), "Right hash variable is only valid in an equijoin"); if (distributionType.isPresent()) { // The implementation of full outer join only works if the data is hash partitioned. @@ -125,10 +125,10 @@ public JoinNode flipChildren() right, left, flipJoinCriteria(criteria), - flipOutputSymbols(getOutputSymbols(), left, right), + flipOutputVariables(getOutputVariables(), left, right), filter, - rightHashSymbol, - leftHashSymbol, + rightHashVariable, + leftHashVariable, distributionType); } @@ -155,17 +155,17 @@ private static List flipJoinCriteria(List joinCr .collect(toImmutableList()); } - private static List flipOutputSymbols(List outputSymbols, PlanNode left, PlanNode right) + private static List flipOutputVariables(List outputVariables, PlanNode left, PlanNode right) { - List leftSymbols = outputSymbols.stream() - .filter(symbol -> left.getOutputSymbols().contains(symbol)) + List leftVariables = outputVariables.stream() + .filter(variable -> left.getOutputVariables().contains(variable)) .collect(Collectors.toList()); - List rightSymbols = outputSymbols.stream() - .filter(symbol -> right.getOutputSymbols().contains(symbol)) + List rightVariables = outputVariables.stream() + .filter(variable -> right.getOutputVariables().contains(variable)) .collect(Collectors.toList()); - return ImmutableList.builder() - .addAll(rightSymbols) - .addAll(leftSymbols) + return ImmutableList.builder() + .addAll(rightVariables) + .addAll(leftVariables) .build(); } @@ -207,31 +207,31 @@ public boolean mustReplicate(List criteria) } } - @JsonProperty("type") + @JsonProperty public Type getType() { return type; } - @JsonProperty("left") + @JsonProperty public PlanNode getLeft() { return left; } - @JsonProperty("right") + @JsonProperty public PlanNode getRight() { return right; } - @JsonProperty("criteria") + @JsonProperty public List getCriteria() { return criteria; } - @JsonProperty("filter") + @JsonProperty public Optional getFilter() { return filter; @@ -240,19 +240,19 @@ public Optional getFilter() public Optional getSortExpressionContext(FunctionManager functionManager) { return filter - .flatMap(filter -> extractSortExpression(ImmutableSet.copyOf(right.getOutputSymbols()), filter, functionManager)); + .flatMap(filter -> extractSortExpression(ImmutableSet.copyOf(right.getOutputVariables()), filter, functionManager)); } - @JsonProperty("leftHashSymbol") - public Optional getLeftHashSymbol() + @JsonProperty + public Optional getLeftHashVariable() { - return leftHashSymbol; + return leftHashVariable; } - @JsonProperty("rightHashSymbol") - public Optional getRightHashSymbol() + @JsonProperty + public Optional getRightHashVariable() { - return rightHashSymbol; + return rightHashVariable; } @Override @@ -262,13 +262,13 @@ public List getSources() } @Override - @JsonProperty("outputSymbols") - public List getOutputSymbols() + @JsonProperty + public List getOutputVariables() { - return outputSymbols; + return outputVariables; } - @JsonProperty("distributionType") + @JsonProperty public Optional getDistributionType() { return distributionType; @@ -284,12 +284,12 @@ public R accept(InternalPlanVisitor visitor, C context) public PlanNode replaceChildren(List newChildren) { checkArgument(newChildren.size() == 2, "expected newChildren to contain 2 nodes"); - return new JoinNode(getId(), type, newChildren.get(0), newChildren.get(1), criteria, outputSymbols, filter, leftHashSymbol, rightHashSymbol, distributionType); + return new JoinNode(getId(), type, newChildren.get(0), newChildren.get(1), criteria, outputVariables, filter, leftHashVariable, rightHashVariable, distributionType); } public JoinNode withDistributionType(DistributionType distributionType) { - return new JoinNode(getId(), type, left, right, criteria, outputSymbols, filter, leftHashSymbol, rightHashSymbol, Optional.of(distributionType)); + return new JoinNode(getId(), type, left, right, criteria, outputVariables, filter, leftHashVariable, rightHashVariable, Optional.of(distributionType)); } public boolean isCrossJoin() @@ -299,24 +299,24 @@ public boolean isCrossJoin() public static class EquiJoinClause { - private final Symbol left; - private final Symbol right; + private final VariableReferenceExpression left; + private final VariableReferenceExpression right; @JsonCreator - public EquiJoinClause(@JsonProperty("left") Symbol left, @JsonProperty("right") Symbol right) + public EquiJoinClause(@JsonProperty("left") VariableReferenceExpression left, @JsonProperty("right") VariableReferenceExpression right) { this.left = requireNonNull(left, "left is null"); this.right = requireNonNull(right, "right is null"); } - @JsonProperty("left") - public Symbol getLeft() + @JsonProperty + public VariableReferenceExpression getLeft() { return left; } - @JsonProperty("right") - public Symbol getRight() + @JsonProperty + public VariableReferenceExpression getRight() { return right; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/LateralJoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/LateralJoinNode.java index 03a82843efcc6..a730f916ee076 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/LateralJoinNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/LateralJoinNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -59,9 +59,9 @@ public JoinNode.Type toJoinNodeType() private final PlanNode subquery; /** - * Correlation symbols, returned from input (outer plan) used in subquery (inner plan) + * Correlation variables, returned from input (outer plan) used in subquery (inner plan) */ - private final List correlation; + private final List correlation; private final Type type; /** @@ -74,7 +74,7 @@ public LateralJoinNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("input") PlanNode input, @JsonProperty("subquery") PlanNode subquery, - @JsonProperty("correlation") List correlation, + @JsonProperty("correlation") List correlation, @JsonProperty("type") Type type, @JsonProperty("originSubqueryError") String originSubqueryError) { @@ -84,7 +84,7 @@ public LateralJoinNode( requireNonNull(correlation, "correlation is null"); requireNonNull(originSubqueryError, "originSubqueryError is null"); - checkArgument(input.getOutputSymbols().containsAll(correlation), "Input does not contain symbols from correlation"); + checkArgument(input.getOutputVariables().containsAll(correlation), "Input does not contain symbols from correlation"); this.input = input; this.subquery = subquery; @@ -93,31 +93,31 @@ public LateralJoinNode( this.originSubqueryError = originSubqueryError; } - @JsonProperty("input") + @JsonProperty public PlanNode getInput() { return input; } - @JsonProperty("subquery") + @JsonProperty public PlanNode getSubquery() { return subquery; } - @JsonProperty("correlation") - public List getCorrelation() + @JsonProperty + public List getCorrelation() { return correlation; } - @JsonProperty("type") + @JsonProperty public Type getType() { return type; } - @JsonProperty("originSubqueryError") + @JsonProperty public String getOriginSubqueryError() { return originSubqueryError; @@ -130,12 +130,11 @@ public List getSources() } @Override - @JsonProperty("outputSymbols") - public List getOutputSymbols() + public List getOutputVariables() { - return ImmutableList.builder() - .addAll(input.getOutputSymbols()) - .addAll(subquery.getOutputSymbols()) + return ImmutableList.builder() + .addAll(input.getOutputVariables()) + .addAll(subquery.getOutputVariables()) .build(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/LimitNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/LimitNode.java index d9cd897525fc5..2da17933c4ceb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/LimitNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/LimitNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -77,9 +77,9 @@ public boolean isPartial() } @Override - public List getOutputSymbols() + public List getOutputVariables() { - return source.getOutputSymbols(); + return source.getOutputVariables(); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MarkDistinctNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MarkDistinctNode.java index 22bc20211bc62..483a32ff73190 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MarkDistinctNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MarkDistinctNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -33,33 +33,33 @@ public class MarkDistinctNode extends InternalPlanNode { private final PlanNode source; - private final Symbol markerSymbol; + private final VariableReferenceExpression markerVariable; - private final Optional hashSymbol; - private final List distinctSymbols; + private final Optional hashVariable; + private final List distinctVariables; @JsonCreator public MarkDistinctNode(@JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("markerSymbol") Symbol markerSymbol, - @JsonProperty("distinctSymbols") List distinctSymbols, - @JsonProperty("hashSymbol") Optional hashSymbol) + @JsonProperty("markerVariable") VariableReferenceExpression markerVariable, + @JsonProperty("distinctVariables") List distinctVariables, + @JsonProperty("hashVariable") Optional hashVariable) { super(id); this.source = source; - this.markerSymbol = markerSymbol; - this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); - requireNonNull(distinctSymbols, "distinctSymbols is null"); - checkArgument(!distinctSymbols.isEmpty(), "distinctSymbols cannot be empty"); - this.distinctSymbols = ImmutableList.copyOf(distinctSymbols); + this.markerVariable = markerVariable; + this.hashVariable = requireNonNull(hashVariable, "hashVariable is null"); + requireNonNull(distinctVariables, "distinctVariables is null"); + checkArgument(!distinctVariables.isEmpty(), "distinctVariables cannot be empty"); + this.distinctVariables = ImmutableList.copyOf(distinctVariables); } @Override - public List getOutputSymbols() + public List getOutputVariables() { - return ImmutableList.builder() - .addAll(source.getOutputSymbols()) - .add(markerSymbol) + return ImmutableList.builder() + .addAll(source.getOutputVariables()) + .add(markerVariable) .build(); } @@ -76,21 +76,21 @@ public PlanNode getSource() } @JsonProperty - public Symbol getMarkerSymbol() + public VariableReferenceExpression getMarkerVariable() { - return markerSymbol; + return markerVariable; } @JsonProperty - public List getDistinctSymbols() + public List getDistinctVariables() { - return distinctSymbols; + return distinctVariables; } @JsonProperty - public Optional getHashSymbol() + public Optional getHashVariable() { - return hashSymbol; + return hashVariable; } @Override @@ -102,6 +102,6 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new MarkDistinctNode(getId(), Iterables.getOnlyElement(newChildren), markerSymbol, distinctSymbols, hashSymbol); + return new MarkDistinctNode(getId(), Iterables.getOnlyElement(newChildren), markerVariable, distinctVariables, hashVariable); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MetadataDeleteNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MetadataDeleteNode.java index 54b21ff4d16d5..f26e354669214 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MetadataDeleteNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MetadataDeleteNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.TableWriterNode.DeleteHandle; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -31,13 +31,13 @@ public class MetadataDeleteNode extends InternalPlanNode { private final DeleteHandle target; - private final Symbol output; + private final VariableReferenceExpression output; @JsonCreator public MetadataDeleteNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("target") DeleteHandle target, - @JsonProperty("output") Symbol output) + @JsonProperty("output") VariableReferenceExpression output) { super(id); @@ -52,13 +52,13 @@ public DeleteHandle getTarget() } @JsonProperty - public Symbol getOutput() + public VariableReferenceExpression getOutput() { return output; } @Override - public List getOutputSymbols() + public List getOutputVariables() { return ImmutableList.of(output); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/OutputNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/OutputNode.java index 5c61486954d6b..a56e011768bcd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/OutputNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/OutputNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; @@ -33,23 +33,23 @@ public class OutputNode { private final PlanNode source; private final List columnNames; - private final List outputs; // column name = symbol + private final List outputVariables; // column name = variable.name @JsonCreator public OutputNode(@JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("columns") List columnNames, - @JsonProperty("outputs") List outputs) + @JsonProperty("columnNames") List columnNames, + @JsonProperty("outputVariables") List outputVariables) { super(id); requireNonNull(source, "source is null"); requireNonNull(columnNames, "columnNames is null"); - Preconditions.checkArgument(columnNames.size() == outputs.size(), "columnNames and assignments sizes don't match"); + Preconditions.checkArgument(columnNames.size() == outputVariables.size(), "columnNames and assignments sizes don't match"); this.source = source; this.columnNames = columnNames; - this.outputs = ImmutableList.copyOf(outputs); + this.outputVariables = ImmutableList.copyOf(outputVariables); } @Override @@ -59,13 +59,13 @@ public List getSources() } @Override - @JsonProperty("outputs") - public List getOutputSymbols() + @JsonProperty + public List getOutputVariables() { - return outputs; + return outputVariables; } - @JsonProperty("columns") + @JsonProperty public List getColumnNames() { return columnNames; @@ -86,6 +86,6 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new OutputNode(getId(), Iterables.getOnlyElement(newChildren), columnNames, outputs); + return new OutputNode(getId(), Iterables.getOnlyElement(newChildren), columnNames, outputVariables); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java index abf88c52fba04..64baa07aca6e3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java @@ -16,7 +16,7 @@ import com.facebook.presto.matching.Pattern; import com.facebook.presto.matching.Property; import com.facebook.presto.spi.relation.RowExpression; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import java.util.List; import java.util.Optional; @@ -178,7 +178,7 @@ public static Property> sources() public static class Aggregation { - public static Property> groupingColumns() + public static Property> groupingColumns() { return property("groupingKeys", AggregationNode::getGroupingKeys); } @@ -191,7 +191,7 @@ public static Property step() public static class Apply { - public static Property> correlation() + public static Property> correlation() { return property("correlation", ApplyNode::getCorrelation); } @@ -207,7 +207,7 @@ public static Property type() public static class LateralJoin { - public static Property> correlation() + public static Property> correlation() { return property("correlation", LateralJoinNode::getCorrelation); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanNode.java index 1bb1d5e6f5dc2..e3952c0dc77f6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeInfo; @@ -52,7 +52,7 @@ public PlanNodeId getId() * The output from the upstream PlanNodes. * It should serve as the input for the current PlanNode. */ - public abstract List getOutputSymbols(); + public abstract List getOutputVariables(); /** * Alter the upstream PlanNodes of the current PlanNode. diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ProjectNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ProjectNode.java index d7ac7aa9cb13a..32560b8c4b30f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ProjectNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ProjectNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -49,7 +49,7 @@ public ProjectNode(@JsonProperty("id") PlanNodeId id, } @Override - public List getOutputSymbols() + public List getOutputVariables() { return assignments.getOutputs(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RemoteSourceNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RemoteSourceNode.java index 6b5ad200260f5..795484f8b593f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RemoteSourceNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RemoteSourceNode.java @@ -14,8 +14,8 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.OrderingScheme; -import com.facebook.presto.sql.planner.Symbol; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -33,7 +33,7 @@ public class RemoteSourceNode extends InternalPlanNode { private final List sourceFragmentIds; - private final List outputs; + private final List outputVariables; private final Optional orderingScheme; private final ExchangeNode.Type exchangeType; // This is needed to "unfragment" to compute stats correctly. @@ -41,23 +41,26 @@ public class RemoteSourceNode public RemoteSourceNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("sourceFragmentIds") List sourceFragmentIds, - @JsonProperty("outputs") List outputs, + @JsonProperty("outputVariables") List outputVariables, @JsonProperty("orderingScheme") Optional orderingScheme, @JsonProperty("exchangeType") ExchangeNode.Type exchangeType) { super(id); - requireNonNull(outputs, "outputs is null"); - this.sourceFragmentIds = sourceFragmentIds; - this.outputs = ImmutableList.copyOf(outputs); + this.outputVariables = ImmutableList.copyOf(requireNonNull(outputVariables, "outputVariables is null")); this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); this.exchangeType = requireNonNull(exchangeType, "exchangeType is null"); } - public RemoteSourceNode(PlanNodeId id, PlanFragmentId sourceFragmentId, List outputs, Optional orderingScheme, ExchangeNode.Type exchangeType) + public RemoteSourceNode( + PlanNodeId id, + PlanFragmentId sourceFragmentId, + List outputVariables, + Optional orderingScheme, + ExchangeNode.Type exchangeType) { - this(id, ImmutableList.of(sourceFragmentId), outputs, orderingScheme, exchangeType); + this(id, ImmutableList.of(sourceFragmentId), outputVariables, orderingScheme, exchangeType); } @Override @@ -67,25 +70,25 @@ public List getSources() } @Override - @JsonProperty("outputs") - public List getOutputSymbols() + @JsonProperty + public List getOutputVariables() { - return outputs; + return outputVariables; } - @JsonProperty("sourceFragmentIds") + @JsonProperty public List getSourceFragmentIds() { return sourceFragmentIds; } - @JsonProperty("orderingScheme") + @JsonProperty public Optional getOrderingScheme() { return orderingScheme; } - @JsonProperty("exchangeType") + @JsonProperty public ExchangeNode.Type getExchangeType() { return exchangeType; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RowNumberNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RowNumberNode.java index 85655811c8617..90977919ca2b7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RowNumberNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RowNumberNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -25,7 +25,6 @@ import java.util.List; import java.util.Optional; -import static com.google.common.collect.Iterables.concat; import static java.util.Objects.requireNonNull; @Immutable @@ -33,33 +32,33 @@ public final class RowNumberNode extends InternalPlanNode { private final PlanNode source; - private final List partitionBy; + private final List partitionBy; private final Optional maxRowCountPerPartition; - private final Symbol rowNumberSymbol; - private final Optional hashSymbol; + private final VariableReferenceExpression rowNumberVariable; + private final Optional hashVariable; @JsonCreator public RowNumberNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("partitionBy") List partitionBy, - @JsonProperty("rowNumberSymbol") Symbol rowNumberSymbol, + @JsonProperty("partitionBy") List partitionBy, + @JsonProperty("rowNumberVariable") VariableReferenceExpression rowNumberVariable, @JsonProperty("maxRowCountPerPartition") Optional maxRowCountPerPartition, - @JsonProperty("hashSymbol") Optional hashSymbol) + @JsonProperty("hashVariable") Optional hashVariable) { super(id); requireNonNull(source, "source is null"); requireNonNull(partitionBy, "partitionBy is null"); - requireNonNull(rowNumberSymbol, "rowNumberSymbol is null"); + requireNonNull(rowNumberVariable, "rowNumberVariable is null"); requireNonNull(maxRowCountPerPartition, "maxRowCountPerPartition is null"); - requireNonNull(hashSymbol, "hashSymbol is null"); + requireNonNull(hashVariable, "hashVariable is null"); this.source = source; this.partitionBy = ImmutableList.copyOf(partitionBy); - this.rowNumberSymbol = rowNumberSymbol; + this.rowNumberVariable = rowNumberVariable; this.maxRowCountPerPartition = maxRowCountPerPartition; - this.hashSymbol = hashSymbol; + this.hashVariable = hashVariable; } @Override @@ -69,9 +68,12 @@ public List getSources() } @Override - public List getOutputSymbols() + public List getOutputVariables() { - return ImmutableList.copyOf(concat(source.getOutputSymbols(), ImmutableList.of(rowNumberSymbol))); + return ImmutableList.builder() + .addAll(source.getOutputVariables()) + .add(rowNumberVariable) + .build(); } @JsonProperty @@ -81,15 +83,15 @@ public PlanNode getSource() } @JsonProperty - public List getPartitionBy() + public List getPartitionBy() { return partitionBy; } @JsonProperty - public Symbol getRowNumberSymbol() + public VariableReferenceExpression getRowNumberVariable() { - return rowNumberSymbol; + return rowNumberVariable; } @JsonProperty @@ -99,9 +101,9 @@ public Optional getMaxRowCountPerPartition() } @JsonProperty - public Optional getHashSymbol() + public Optional getHashVariable() { - return hashSymbol; + return hashVariable; } @Override @@ -113,6 +115,6 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new RowNumberNode(getId(), Iterables.getOnlyElement(newChildren), partitionBy, rowNumberSymbol, maxRowCountPerPartition, hashSymbol); + return new RowNumberNode(getId(), Iterables.getOnlyElement(newChildren), partitionBy, rowNumberVariable, maxRowCountPerPartition, hashVariable); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SampleNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SampleNode.java index c84ce97d39c0b..a1f900583b4f3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SampleNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SampleNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -83,9 +83,9 @@ public Type getSampleType() } @Override - public List getOutputSymbols() + public List getOutputVariables() { - return source.getOutputSymbols(); + return source.getOutputVariables(); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SemiJoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SemiJoinNode.java index fd28eb4c6b7e9..4dd196e9a3af7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SemiJoinNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SemiJoinNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -33,36 +33,36 @@ public class SemiJoinNode { private final PlanNode source; private final PlanNode filteringSource; - private final Symbol sourceJoinSymbol; - private final Symbol filteringSourceJoinSymbol; - private final Symbol semiJoinOutput; - private final Optional sourceHashSymbol; - private final Optional filteringSourceHashSymbol; + private final VariableReferenceExpression sourceJoinVariable; + private final VariableReferenceExpression filteringSourceJoinVariable; + private final VariableReferenceExpression semiJoinOutput; + private final Optional sourceHashVariable; + private final Optional filteringSourceHashVariable; private final Optional distributionType; @JsonCreator public SemiJoinNode(@JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, @JsonProperty("filteringSource") PlanNode filteringSource, - @JsonProperty("sourceJoinSymbol") Symbol sourceJoinSymbol, - @JsonProperty("filteringSourceJoinSymbol") Symbol filteringSourceJoinSymbol, - @JsonProperty("semiJoinOutput") Symbol semiJoinOutput, - @JsonProperty("sourceHashSymbol") Optional sourceHashSymbol, - @JsonProperty("filteringSourceHashSymbol") Optional filteringSourceHashSymbol, + @JsonProperty("sourceJoinVariable") VariableReferenceExpression sourceJoinVariable, + @JsonProperty("filteringSourceJoinVariable") VariableReferenceExpression filteringSourceJoinVariable, + @JsonProperty("semiJoinOutput") VariableReferenceExpression semiJoinOutput, + @JsonProperty("sourceHashVariable") Optional sourceHashVariable, + @JsonProperty("filteringSourceHashVariable") Optional filteringSourceHashVariable, @JsonProperty("distributionType") Optional distributionType) { super(id); this.source = requireNonNull(source, "source is null"); this.filteringSource = requireNonNull(filteringSource, "filteringSource is null"); - this.sourceJoinSymbol = requireNonNull(sourceJoinSymbol, "sourceJoinSymbol is null"); - this.filteringSourceJoinSymbol = requireNonNull(filteringSourceJoinSymbol, "filteringSourceJoinSymbol is null"); + this.sourceJoinVariable = requireNonNull(sourceJoinVariable, "sourceJoinVariable is null"); + this.filteringSourceJoinVariable = requireNonNull(filteringSourceJoinVariable, "filteringSourceJoinVariable is null"); this.semiJoinOutput = requireNonNull(semiJoinOutput, "semiJoinOutput is null"); - this.sourceHashSymbol = requireNonNull(sourceHashSymbol, "sourceHashSymbol is null"); - this.filteringSourceHashSymbol = requireNonNull(filteringSourceHashSymbol, "filteringSourceHashSymbol is null"); + this.sourceHashVariable = requireNonNull(sourceHashVariable, "sourceHashVariable is null"); + this.filteringSourceHashVariable = requireNonNull(filteringSourceHashVariable, "filteringSourceHashVariable is null"); this.distributionType = requireNonNull(distributionType, "distributionType is null"); - checkArgument(source.getOutputSymbols().contains(sourceJoinSymbol), "Source does not contain join symbol"); - checkArgument(filteringSource.getOutputSymbols().contains(filteringSourceJoinSymbol), "Filtering source does not contain filtering join symbol"); + checkArgument(source.getOutputVariables().contains(sourceJoinVariable), "Source does not contain join symbol"); + checkArgument(filteringSource.getOutputVariables().contains(filteringSourceJoinVariable), "Filtering source does not contain filtering join symbol"); } public enum DistributionType @@ -71,49 +71,49 @@ public enum DistributionType REPLICATED } - @JsonProperty("source") + @JsonProperty public PlanNode getSource() { return source; } - @JsonProperty("filteringSource") + @JsonProperty public PlanNode getFilteringSource() { return filteringSource; } - @JsonProperty("sourceJoinSymbol") - public Symbol getSourceJoinSymbol() + @JsonProperty + public VariableReferenceExpression getSourceJoinVariable() { - return sourceJoinSymbol; + return sourceJoinVariable; } - @JsonProperty("filteringSourceJoinSymbol") - public Symbol getFilteringSourceJoinSymbol() + @JsonProperty + public VariableReferenceExpression getFilteringSourceJoinVariable() { - return filteringSourceJoinSymbol; + return filteringSourceJoinVariable; } - @JsonProperty("semiJoinOutput") - public Symbol getSemiJoinOutput() + @JsonProperty + public VariableReferenceExpression getSemiJoinOutput() { return semiJoinOutput; } - @JsonProperty("sourceHashSymbol") - public Optional getSourceHashSymbol() + @JsonProperty + public Optional getSourceHashVariable() { - return sourceHashSymbol; + return sourceHashVariable; } - @JsonProperty("filteringSourceHashSymbol") - public Optional getFilteringSourceHashSymbol() + @JsonProperty + public Optional getFilteringSourceHashVariable() { - return filteringSourceHashSymbol; + return filteringSourceHashVariable; } - @JsonProperty("distributionType") + @JsonProperty public Optional getDistributionType() { return distributionType; @@ -126,10 +126,10 @@ public List getSources() } @Override - public List getOutputSymbols() + public List getOutputVariables() { - return ImmutableList.builder() - .addAll(source.getOutputSymbols()) + return ImmutableList.builder() + .addAll(source.getOutputVariables()) .add(semiJoinOutput) .build(); } @@ -148,11 +148,11 @@ public PlanNode replaceChildren(List newChildren) getId(), newChildren.get(0), newChildren.get(1), - sourceJoinSymbol, - filteringSourceJoinSymbol, + sourceJoinVariable, + filteringSourceJoinVariable, semiJoinOutput, - sourceHashSymbol, - filteringSourceHashSymbol, + sourceHashVariable, + filteringSourceHashVariable, distributionType); } @@ -162,11 +162,11 @@ public SemiJoinNode withDistributionType(DistributionType distributionType) getId(), source, filteringSource, - sourceJoinSymbol, - filteringSourceJoinSymbol, + sourceJoinVariable, + filteringSourceJoinVariable, semiJoinOutput, - sourceHashSymbol, - filteringSourceHashSymbol, + sourceHashVariable, + filteringSourceHashVariable, Optional.of(distributionType)); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SetOperationNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SetOperationNode.java index a86e40ac3f91c..b06d014d91c05 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SetOperationNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SetOperationNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.FluentIterable; @@ -40,88 +40,84 @@ public abstract class SetOperationNode extends InternalPlanNode { private final List sources; - private final ListMultimap outputToInputs; - private final List outputs; + private final ListMultimap outputToInputs; + private final List outputVariables; @JsonCreator protected SetOperationNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("sources") List sources, - @JsonProperty("outputToInputs") ListMultimap outputToInputs, - @JsonProperty("outputs") List outputs) + @JsonProperty("outputToInputs") ListMultimap outputToInputs) { super(id); requireNonNull(sources, "sources is null"); checkArgument(!sources.isEmpty(), "Must have at least one source"); requireNonNull(outputToInputs, "outputToInputs is null"); - requireNonNull(outputs, "outputs is null"); this.sources = ImmutableList.copyOf(sources); this.outputToInputs = ImmutableListMultimap.copyOf(outputToInputs); - this.outputs = ImmutableList.copyOf(outputs); + this.outputVariables = ImmutableList.copyOf(outputToInputs.keySet()); - for (Collection inputs : this.outputToInputs.asMap().values()) { + for (Collection inputs : this.outputToInputs.asMap().values()) { checkArgument(inputs.size() == this.sources.size(), "Every source needs to map its symbols to an output %s operation symbol", this.getClass().getSimpleName()); } // Make sure each source positionally corresponds to their Symbol values in the Multimap for (int i = 0; i < sources.size(); i++) { - for (Collection expectedInputs : this.outputToInputs.asMap().values()) { - checkArgument(sources.get(i).getOutputSymbols().contains(Iterables.get(expectedInputs, i)), "Source does not provide required symbols"); + for (Collection expectedInputs : this.outputToInputs.asMap().values()) { + checkArgument(sources.get(i).getOutputVariables().contains(Iterables.get(expectedInputs, i)), "Source does not provide required symbols"); } } } @Override - @JsonProperty("sources") + @JsonProperty public List getSources() { return sources; } - @Override - @JsonProperty("outputs") - public List getOutputSymbols() + @JsonProperty + public ListMultimap getVariableMapping() { - return outputs; + return outputToInputs; } - @JsonProperty("outputToInputs") - public ListMultimap getSymbolMapping() + @Override + public List getOutputVariables() { - return outputToInputs; + return outputVariables; } - public List sourceOutputLayout(int sourceIndex) + public List sourceOutputLayout(int sourceIndex) { - // Make sure the sourceOutputLayout symbols are listed in the same order as the corresponding output symbols - return getOutputSymbols().stream() - .map(symbol -> outputToInputs.get(symbol).get(sourceIndex)) + // Make sure the sourceOutputSymbolLayout symbols are listed in the same order as the corresponding output symbols + return getOutputVariables().stream() + .map(variable -> outputToInputs.get(variable).get(sourceIndex)) .collect(toImmutableList()); } /** - * Returns the output to input symbol mapping for the given source channel + * Returns the output to input variable mapping for the given source channel */ - public Map sourceSymbolMap(int sourceIndex) + public Map sourceVariableMap(int sourceIndex) { - ImmutableMap.Builder builder = ImmutableMap.builder(); - for (Map.Entry> entry : outputToInputs.asMap().entrySet()) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (Map.Entry> entry : outputToInputs.asMap().entrySet()) { builder.put(entry.getKey(), Iterables.get(entry.getValue(), sourceIndex)); } return builder.build(); } - /** * Returns the input to output symbol mapping for the given source channel. * A single input symbol can map to multiple output symbols, thus requiring a Multimap. */ - public Multimap outputSymbolMap(int sourceIndex) + public Multimap outputMap(int sourceIndex) { - return FluentIterable.from(getOutputSymbols()) - .toMap(outputSymbol -> outputToInputs.get(outputSymbol).get(sourceIndex)) + return FluentIterable.from(getOutputVariables()) + .toMap(output -> outputToInputs.get(output).get(sourceIndex)) .asMultimap() .inverse(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SortNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SortNode.java index 565276cdeac35..d8fbfb769ffd5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SortNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SortNode.java @@ -14,8 +14,8 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.OrderingScheme; -import com.facebook.presto.sql.planner.Symbol; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -58,9 +58,9 @@ public PlanNode getSource() } @Override - public List getOutputSymbols() + public List getOutputVariables() { - return source.getOutputSymbols(); + return source.getOutputVariables(); } @JsonProperty("orderingScheme") diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SpatialJoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SpatialJoinNode.java index 57921b685383e..d3a04b7ada856 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SpatialJoinNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SpatialJoinNode.java @@ -15,7 +15,7 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.relation.RowExpression; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -67,10 +67,10 @@ public static Type fromJoinNodeType(JoinNode.Type joinNodeType) private final Type type; private final PlanNode left; private final PlanNode right; - private final List outputSymbols; + private final List outputVariables; private final RowExpression filter; - private final Optional leftPartitionSymbol; - private final Optional rightPartitionSymbol; + private final Optional leftPartitionVariable; + private final Optional rightPartitionVariable; private final Optional kdbTree; private final DistributionType distributionType; @@ -86,10 +86,10 @@ public SpatialJoinNode( @JsonProperty("type") Type type, @JsonProperty("left") PlanNode left, @JsonProperty("right") PlanNode right, - @JsonProperty("outputSymbols") List outputSymbols, + @JsonProperty("outputVariables") List outputVariables, @JsonProperty("filter") RowExpression filter, - @JsonProperty("leftPartitionSymbol") Optional leftPartitionSymbol, - @JsonProperty("rightPartitionSymbol") Optional rightPartitionSymbol, + @JsonProperty("leftPartitionVariable") Optional leftPartitionVariable, + @JsonProperty("rightPartitionVariable") Optional rightPartitionVariable, @JsonProperty("kdbTree") Optional kdbTree) { super(id); @@ -97,66 +97,66 @@ public SpatialJoinNode( this.type = requireNonNull(type, "type is null"); this.left = requireNonNull(left, "left is null"); this.right = requireNonNull(right, "right is null"); - this.outputSymbols = ImmutableList.copyOf(requireNonNull(outputSymbols, "outputSymbols is null")); + this.outputVariables = ImmutableList.copyOf(requireNonNull(outputVariables, "outputVariables is null")); this.filter = requireNonNull(filter, "filter is null"); - this.leftPartitionSymbol = requireNonNull(leftPartitionSymbol, "leftPartitionSymbol is null"); - this.rightPartitionSymbol = requireNonNull(rightPartitionSymbol, "rightPartitionSymbol is null"); + this.leftPartitionVariable = requireNonNull(leftPartitionVariable, "leftPartitionVariable is null"); + this.rightPartitionVariable = requireNonNull(rightPartitionVariable, "rightPartitionVariable is null"); this.kdbTree = requireNonNull(kdbTree, "kdbTree is null"); - Set inputSymbols = ImmutableSet.builder() - .addAll(left.getOutputSymbols()) - .addAll(right.getOutputSymbols()) + Set inputSymbols = ImmutableSet.builder() + .addAll(left.getOutputVariables()) + .addAll(right.getOutputVariables()) .build(); - checkArgument(inputSymbols.containsAll(outputSymbols), "Left and right join inputs do not contain all output symbols"); + checkArgument(inputSymbols.containsAll(outputVariables), "Left and right join inputs do not contain all output variables"); if (kdbTree.isPresent()) { - checkArgument(leftPartitionSymbol.isPresent(), "Left partition symbol is missing"); - checkArgument(rightPartitionSymbol.isPresent(), "Right partition symbol is missing"); - checkArgument(left.getOutputSymbols().contains(leftPartitionSymbol.get()), "Left join input does not contain left partition symbol"); - checkArgument(right.getOutputSymbols().contains(rightPartitionSymbol.get()), "Right join input does not contain right partition symbol"); + checkArgument(leftPartitionVariable.isPresent(), "Left partition variable is missing"); + checkArgument(rightPartitionVariable.isPresent(), "Right partition variable is missing"); + checkArgument(left.getOutputVariables().contains(leftPartitionVariable.get()), "Left join input does not contain left partition variable"); + checkArgument(right.getOutputVariables().contains(rightPartitionVariable.get()), "Right join input does not contain right partition variable"); this.distributionType = DistributionType.PARTITIONED; } else { - checkArgument(!leftPartitionSymbol.isPresent(), "KDB tree is missing"); - checkArgument(!rightPartitionSymbol.isPresent(), "KDB tree is missing"); + checkArgument(!leftPartitionVariable.isPresent(), "KDB tree is missing"); + checkArgument(!rightPartitionVariable.isPresent(), "KDB tree is missing"); this.distributionType = DistributionType.REPLICATED; } } - @JsonProperty("type") + @JsonProperty public Type getType() { return type; } - @JsonProperty("left") + @JsonProperty public PlanNode getLeft() { return left; } - @JsonProperty("right") + @JsonProperty public PlanNode getRight() { return right; } - @JsonProperty("filter") + @JsonProperty public RowExpression getFilter() { return filter; } - @JsonProperty("leftPartitionSymbol") - public Optional getLeftPartitionSymbol() + @JsonProperty + public Optional getLeftPartitionVariable() { - return leftPartitionSymbol; + return leftPartitionVariable; } - @JsonProperty("rightPartitionSymbol") - public Optional getRightPartitionSymbol() + @JsonProperty + public Optional getRightPartitionVariable() { - return rightPartitionSymbol; + return rightPartitionVariable; } @Override @@ -166,19 +166,19 @@ public List getSources() } @Override - @JsonProperty("outputSymbols") - public List getOutputSymbols() + @JsonProperty + public List getOutputVariables() { - return outputSymbols; + return outputVariables; } - @JsonProperty("distributionType") + @JsonProperty public DistributionType getDistributionType() { return distributionType; } - @JsonProperty("kdbTree") + @JsonProperty public Optional getKdbTree() { return kdbTree; @@ -194,6 +194,6 @@ public R accept(InternalPlanVisitor visitor, C context) public PlanNode replaceChildren(List newChildren) { checkArgument(newChildren.size() == 2, "expected newChildren to contain 2 nodes"); - return new SpatialJoinNode(getId(), type, newChildren.get(0), newChildren.get(1), outputSymbols, filter, leftPartitionSymbol, rightPartitionSymbol, kdbTree); + return new SpatialJoinNode(getId(), type, newChildren.get(0), newChildren.get(1), outputVariables, filter, leftPartitionVariable, rightPartitionVariable, kdbTree); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregations.java index dc403f735914e..df0a0060e39bc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregations.java @@ -16,9 +16,11 @@ import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; +import com.facebook.presto.sql.tree.SymbolReference; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -28,46 +30,55 @@ import java.util.Map; import java.util.Optional; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; public class StatisticAggregations { - private final Map aggregations; - private final List groupingSymbols; + private final Map aggregations; + private final List groupingVariables; @JsonCreator public StatisticAggregations( - @JsonProperty("aggregations") Map aggregations, - @JsonProperty("groupingSymbols") List groupingSymbols) + @JsonProperty("aggregations") Map aggregations, + @JsonProperty("groupingVariables") List groupingVariables) { this.aggregations = ImmutableMap.copyOf(requireNonNull(aggregations, "aggregations is null")); - this.groupingSymbols = ImmutableList.copyOf(requireNonNull(groupingSymbols, "groupingSymbols is null")); + this.groupingVariables = ImmutableList.copyOf(requireNonNull(groupingVariables, "groupingVariables is null")); } @JsonProperty - public Map getAggregations() + public Map getAggregations() { return aggregations; } @JsonProperty + public List getGroupingVariables() + { + return groupingVariables; + } + public List getGroupingSymbols() { - return groupingSymbols; + return groupingVariables.stream() + .map(VariableReferenceExpression::getName) + .map(Symbol::new) + .collect(toImmutableList()); } public Parts createPartialAggregations(SymbolAllocator symbolAllocator, FunctionManager functionManager) { - ImmutableMap.Builder partialAggregation = ImmutableMap.builder(); - ImmutableMap.Builder finalAggregation = ImmutableMap.builder(); - ImmutableMap.Builder mappings = ImmutableMap.builder(); - for (Map.Entry entry : aggregations.entrySet()) { + ImmutableMap.Builder partialAggregation = ImmutableMap.builder(); + ImmutableMap.Builder finalAggregation = ImmutableMap.builder(); + ImmutableMap.Builder mappings = ImmutableMap.builder(); + for (Map.Entry entry : aggregations.entrySet()) { Aggregation originalAggregation = entry.getValue(); FunctionHandle functionHandle = originalAggregation.getFunctionHandle(); InternalAggregationFunction function = functionManager.getAggregateFunctionImplementation(functionHandle); - Symbol partialSymbol = symbolAllocator.newSymbol(functionManager.getFunctionMetadata(functionHandle).getName(), function.getIntermediateType()); - mappings.put(entry.getKey(), partialSymbol); - partialAggregation.put(partialSymbol, new Aggregation( + VariableReferenceExpression partialVariable = symbolAllocator.newVariable(functionManager.getFunctionMetadata(functionHandle).getName(), function.getIntermediateType()); + mappings.put(entry.getKey(), partialVariable); + partialAggregation.put(partialVariable, new Aggregation( functionHandle, originalAggregation.getArguments(), originalAggregation.getFilter(), @@ -77,16 +88,16 @@ public Parts createPartialAggregations(SymbolAllocator symbolAllocator, Function finalAggregation.put(entry.getKey(), new Aggregation( functionHandle, - ImmutableList.of(partialSymbol.toSymbolReference()), + ImmutableList.of(new SymbolReference(partialVariable.getName())), Optional.empty(), Optional.empty(), false, Optional.empty())); } - groupingSymbols.forEach(symbol -> mappings.put(symbol, symbol)); + groupingVariables.forEach(symbol -> mappings.put(symbol, symbol)); return new Parts( - new StatisticAggregations(partialAggregation.build(), groupingSymbols), - new StatisticAggregations(finalAggregation.build(), groupingSymbols), + new StatisticAggregations(partialAggregation.build(), groupingVariables), + new StatisticAggregations(finalAggregation.build(), groupingVariables), mappings.build()); } @@ -94,9 +105,9 @@ public static class Parts { private final StatisticAggregations partialAggregation; private final StatisticAggregations finalAggregation; - private final Map mappings; + private final Map mappings; - public Parts(StatisticAggregations partialAggregation, StatisticAggregations finalAggregation, Map mappings) + public Parts(StatisticAggregations partialAggregation, StatisticAggregations finalAggregation, Map mappings) { this.partialAggregation = requireNonNull(partialAggregation, "partialAggregation is null"); this.finalAggregation = requireNonNull(finalAggregation, "finalAggregation is null"); @@ -113,7 +124,7 @@ public StatisticAggregations getFinalAggregation() return finalAggregation; } - public Map getMappings() + public Map getMappings() { return mappings; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticsWriterNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticsWriterNode.java index 7b6ac3fab2f63..5fe7c6164975b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticsWriterNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticsWriterNode.java @@ -16,7 +16,7 @@ import com.facebook.presto.metadata.AnalyzeTableHandle; import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonSubTypes; @@ -32,24 +32,24 @@ public class StatisticsWriterNode extends InternalPlanNode { private final PlanNode source; - private final Symbol rowCountSymbol; + private final VariableReferenceExpression rowCountVariable; private final WriteStatisticsTarget target; private final boolean rowCountEnabled; - private final StatisticAggregationsDescriptor descriptor; + private final StatisticAggregationsDescriptor descriptor; @JsonCreator public StatisticsWriterNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, @JsonProperty("target") WriteStatisticsTarget target, - @JsonProperty("rowCountSymbol") Symbol rowCountSymbol, + @JsonProperty("rowCountVariable") VariableReferenceExpression rowCountVariable, @JsonProperty("rowCountEnabled") boolean rowCountEnabled, - @JsonProperty("descriptor") StatisticAggregationsDescriptor descriptor) + @JsonProperty("descriptor") StatisticAggregationsDescriptor descriptor) { super(id); this.source = requireNonNull(source, "source is null"); this.target = requireNonNull(target, "target is null"); - this.rowCountSymbol = requireNonNull(rowCountSymbol, "rowCountSymbol is null"); + this.rowCountVariable = requireNonNull(rowCountVariable, "rowCountVariable is null"); this.rowCountEnabled = rowCountEnabled; this.descriptor = requireNonNull(descriptor, "descriptor is null"); } @@ -67,15 +67,15 @@ public WriteStatisticsTarget getTarget() } @JsonProperty - public StatisticAggregationsDescriptor getDescriptor() + public StatisticAggregationsDescriptor getDescriptor() { return descriptor; } @JsonProperty - public Symbol getRowCountSymbol() + public VariableReferenceExpression getRowCountVariable() { - return rowCountSymbol; + return rowCountVariable; } @JsonProperty @@ -91,9 +91,9 @@ public List getSources() } @Override - public List getOutputSymbols() + public List getOutputVariables() { - return ImmutableList.of(rowCountSymbol); + return ImmutableList.of(rowCountVariable); } @Override @@ -103,7 +103,7 @@ public PlanNode replaceChildren(List newChildren) getId(), Iterables.getOnlyElement(newChildren), target, - rowCountSymbol, + rowCountVariable, rowCountEnabled, descriptor); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFinishNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFinishNode.java index 6807312e243fb..def625f9830e0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFinishNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFinishNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -35,25 +35,25 @@ public class TableFinishNode { private final PlanNode source; private final WriterTarget target; - private final Symbol rowCountSymbol; + private final VariableReferenceExpression rowCountVariable; private final Optional statisticsAggregation; - private final Optional> statisticsAggregationDescriptor; + private final Optional> statisticsAggregationDescriptor; @JsonCreator public TableFinishNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, @JsonProperty("target") WriterTarget target, - @JsonProperty("rowCountSymbol") Symbol rowCountSymbol, + @JsonProperty("rowCountVariable") VariableReferenceExpression rowCountVariable, @JsonProperty("statisticsAggregation") Optional statisticsAggregation, - @JsonProperty("statisticsAggregationDescriptor") Optional> statisticsAggregationDescriptor) + @JsonProperty("statisticsAggregationDescriptor") Optional> statisticsAggregationDescriptor) { super(id); checkArgument(target != null || source instanceof TableWriterNode); this.source = requireNonNull(source, "source is null"); this.target = requireNonNull(target, "target is null"); - this.rowCountSymbol = requireNonNull(rowCountSymbol, "rowCountSymbol is null"); + this.rowCountVariable = requireNonNull(rowCountVariable, "rowCountVariable is null"); this.statisticsAggregation = requireNonNull(statisticsAggregation, "statisticsAggregation is null"); this.statisticsAggregationDescriptor = requireNonNull(statisticsAggregationDescriptor, "statisticsAggregationDescriptor is null"); checkArgument(statisticsAggregation.isPresent() == statisticsAggregationDescriptor.isPresent(), "statisticsAggregation and statisticsAggregationDescriptor must both be either present or absent"); @@ -72,9 +72,9 @@ public WriterTarget getTarget() } @JsonProperty - public Symbol getRowCountSymbol() + public VariableReferenceExpression getRowCountVariable() { - return rowCountSymbol; + return rowCountVariable; } @JsonProperty @@ -84,7 +84,7 @@ public Optional getStatisticsAggregation() } @JsonProperty - public Optional> getStatisticsAggregationDescriptor() + public Optional> getStatisticsAggregationDescriptor() { return statisticsAggregationDescriptor; } @@ -96,9 +96,9 @@ public List getSources() } @Override - public List getOutputSymbols() + public List getOutputVariables() { - return ImmutableList.of(rowCountSymbol); + return ImmutableList.of(rowCountVariable); } @Override @@ -114,7 +114,7 @@ public PlanNode replaceChildren(List newChildren) getId(), Iterables.getOnlyElement(newChildren), target, - rowCountSymbol, + rowCountVariable, statisticsAggregation, statisticsAggregationDescriptor); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableScanNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableScanNode.java index f7623a1da7b38..9ccf70e8ac05f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableScanNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableScanNode.java @@ -17,13 +17,12 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.predicate.TupleDomain; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import javax.annotation.concurrent.Immutable; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -38,8 +37,8 @@ public class TableScanNode extends PlanNode { private final TableHandle table; - private final List outputSymbols; - private final Map assignments; + private final Map assignments; + private final List outputVariables; // Used during predicate refinement over multiple passes of predicate pushdown // TODO: think about how to get rid of this in new planner @@ -55,16 +54,16 @@ public class TableScanNode public TableScanNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("table") TableHandle table, - @JsonProperty("outputSymbols") List outputs, - @JsonProperty("assignments") Map assignments, + @JsonProperty("outputVariables") List outputVariables, + @JsonProperty("assignments") Map assignments, @JsonProperty("temporaryTable") boolean temporaryTable) { // This constructor is for JSON deserialization only. Do not use. super(id); this.table = requireNonNull(table, "table is null"); - this.outputSymbols = unmodifiableList(new ArrayList<>(requireNonNull(outputs, "outputs is null"))); + this.outputVariables = unmodifiableList(requireNonNull(outputVariables, "outputVariables is null")); this.assignments = unmodifiableMap(new HashMap<>(requireNonNull(assignments, "assignments is null"))); - checkArgument(assignments.keySet().containsAll(outputs), "assignments does not cover all of outputs"); + checkArgument(assignments.keySet().containsAll(outputVariables), "assignments does not cover all of outputs"); this.temporaryTable = temporaryTable; this.currentConstraint = null; this.enforcedConstraint = null; @@ -73,37 +72,37 @@ public TableScanNode( public TableScanNode( PlanNodeId id, TableHandle table, - List outputs, - Map assignments) + List outputVariables, + Map assignments) { - this(id, table, outputs, assignments, TupleDomain.all(), TupleDomain.all(), false); + this(id, table, outputVariables, assignments, TupleDomain.all(), TupleDomain.all(), false); } public TableScanNode( PlanNodeId id, TableHandle table, - List outputs, - Map assignments, + List outputVariables, + Map assignments, TupleDomain currentConstraint, TupleDomain enforcedConstraint) { - this(id, table, outputs, assignments, currentConstraint, enforcedConstraint, false); + this(id, table, outputVariables, assignments, currentConstraint, enforcedConstraint, false); } public TableScanNode( PlanNodeId id, TableHandle table, - List outputs, - Map assignments, + List outputVariables, + Map assignments, TupleDomain currentConstraint, TupleDomain enforcedConstraint, boolean temporaryTable) { super(id); this.table = requireNonNull(table, "table is null"); - this.outputSymbols = unmodifiableList(new ArrayList<>(requireNonNull(outputs, "outputs is null"))); + this.outputVariables = unmodifiableList(requireNonNull(outputVariables, "outputVariables is null")); this.assignments = unmodifiableMap(new HashMap<>(requireNonNull(assignments, "assignments is null"))); - checkArgument(assignments.keySet().containsAll(outputs), "assignments does not cover all of outputs"); + checkArgument(assignments.keySet().containsAll(outputVariables), "assignments does not cover all of outputs"); this.currentConstraint = requireNonNull(currentConstraint, "currentConstraint is null"); this.enforcedConstraint = requireNonNull(enforcedConstraint, "enforcedConstraint is null"); this.temporaryTable = temporaryTable; @@ -124,8 +123,8 @@ public TableHandle getTable() /** * Get the mapping from symbols to columns */ - @JsonProperty("assignments") - public Map getAssignments() + @JsonProperty + public Map getAssignments() { return assignments; } @@ -173,10 +172,10 @@ public List getSources() } @Override - @JsonProperty("outputSymbols") - public List getOutputSymbols() + @JsonProperty + public List getOutputVariables() { - return outputSymbols; + return outputVariables; } @Override @@ -191,7 +190,7 @@ public String toString() StringBuilder stringBuilder = new StringBuilder(this.getClass().getSimpleName()); stringBuilder.append(" {"); stringBuilder.append("table='").append(table).append('\''); - stringBuilder.append(", outputSymbols='").append(outputSymbols).append('\''); + stringBuilder.append(", outputVariables='").append(outputVariables).append('\''); stringBuilder.append(", assignments='").append(assignments).append('\''); stringBuilder.append(", currentConstraint='").append(currentConstraint).append('\''); stringBuilder.append(", enforcedConstraint='").append(enforcedConstraint).append('\''); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableWriterNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableWriterNode.java index e9a17c7133817..267b2995c9b09 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableWriterNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableWriterNode.java @@ -20,8 +20,8 @@ import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.PartitioningScheme; -import com.facebook.presto.sql.planner.Symbol; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonSubTypes; @@ -43,29 +43,29 @@ public class TableWriterNode { private final PlanNode source; private final WriterTarget target; - private final Symbol rowCountSymbol; - private final Symbol fragmentSymbol; - private final Symbol tableCommitContextSymbol; - private final List columns; + private final VariableReferenceExpression rowCountVariable; + private final VariableReferenceExpression fragmentVariable; + private final VariableReferenceExpression tableCommitContextVariable; + private final List columns; private final List columnNames; private final Optional partitioningScheme; private final Optional statisticsAggregation; - private final Optional> statisticsAggregationDescriptor; - private final List outputs; + private final Optional> statisticsAggregationDescriptor; + private final List outputs; @JsonCreator public TableWriterNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, @JsonProperty("target") WriterTarget target, - @JsonProperty("rowCountSymbol") Symbol rowCountSymbol, - @JsonProperty("fragmentSymbol") Symbol fragmentSymbol, - @JsonProperty("tableCommitContextSymbol") Symbol tableCommitContextSymbol, - @JsonProperty("columns") List columns, + @JsonProperty("rowCountVariable") VariableReferenceExpression rowCountVariable, + @JsonProperty("fragmentVariable") VariableReferenceExpression fragmentVariable, + @JsonProperty("tableCommitContextVariable") VariableReferenceExpression tableCommitContextVariable, + @JsonProperty("columns") List columns, @JsonProperty("columnNames") List columnNames, @JsonProperty("partitioningScheme") Optional partitioningScheme, @JsonProperty("statisticsAggregation") Optional statisticsAggregation, - @JsonProperty("statisticsAggregationDescriptor") Optional> statisticsAggregationDescriptor) + @JsonProperty("statisticsAggregationDescriptor") Optional> statisticsAggregationDescriptor) { super(id); @@ -75,9 +75,9 @@ public TableWriterNode( this.source = requireNonNull(source, "source is null"); this.target = requireNonNull(target, "target is null"); - this.rowCountSymbol = requireNonNull(rowCountSymbol, "rowCountSymbol is null"); - this.fragmentSymbol = requireNonNull(fragmentSymbol, "fragmentSymbol is null"); - this.tableCommitContextSymbol = requireNonNull(tableCommitContextSymbol, "tableCommitContextSymbol is null"); + this.rowCountVariable = requireNonNull(rowCountVariable, "rowCountVariable is null"); + this.fragmentVariable = requireNonNull(fragmentVariable, "fragmentVariable is null"); + this.tableCommitContextVariable = requireNonNull(tableCommitContextVariable, "tableCommitContextVariable is null"); this.columns = ImmutableList.copyOf(columns); this.columnNames = ImmutableList.copyOf(columnNames); this.partitioningScheme = requireNonNull(partitioningScheme, "partitioningScheme is null"); @@ -85,12 +85,12 @@ public TableWriterNode( this.statisticsAggregationDescriptor = requireNonNull(statisticsAggregationDescriptor, "statisticsAggregationDescriptor is null"); checkArgument(statisticsAggregation.isPresent() == statisticsAggregationDescriptor.isPresent(), "statisticsAggregation and statisticsAggregationDescriptor must be either present or absent"); - ImmutableList.Builder outputs = ImmutableList.builder() - .add(rowCountSymbol) - .add(fragmentSymbol) - .add(tableCommitContextSymbol); + ImmutableList.Builder outputs = ImmutableList.builder() + .add(rowCountVariable) + .add(fragmentVariable) + .add(tableCommitContextVariable); statisticsAggregation.ifPresent(aggregation -> { - outputs.addAll(aggregation.getGroupingSymbols()); + outputs.addAll(aggregation.getGroupingVariables()); outputs.addAll(aggregation.getAggregations().keySet()); }); this.outputs = outputs.build(); @@ -109,25 +109,25 @@ public WriterTarget getTarget() } @JsonProperty - public Symbol getRowCountSymbol() + public VariableReferenceExpression getRowCountVariable() { - return rowCountSymbol; + return rowCountVariable; } @JsonProperty - public Symbol getFragmentSymbol() + public VariableReferenceExpression getFragmentVariable() { - return fragmentSymbol; + return fragmentVariable; } @JsonProperty - public Symbol getTableCommitContextSymbol() + public VariableReferenceExpression getTableCommitContextVariable() { - return tableCommitContextSymbol; + return tableCommitContextVariable; } @JsonProperty - public List getColumns() + public List getColumns() { return columns; } @@ -151,7 +151,7 @@ public Optional getStatisticsAggregation() } @JsonProperty - public Optional> getStatisticsAggregationDescriptor() + public Optional> getStatisticsAggregationDescriptor() { return statisticsAggregationDescriptor; } @@ -163,7 +163,7 @@ public List getSources() } @Override - public List getOutputSymbols() + public List getOutputVariables() { return outputs; } @@ -177,7 +177,7 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new TableWriterNode(getId(), Iterables.getOnlyElement(newChildren), target, rowCountSymbol, fragmentSymbol, tableCommitContextSymbol, columns, columnNames, partitioningScheme, statisticsAggregation, statisticsAggregationDescriptor); + return new TableWriterNode(getId(), Iterables.getOnlyElement(newChildren), target, rowCountVariable, fragmentVariable, tableCommitContextVariable, columns, columnNames, partitioningScheme, statisticsAggregation, statisticsAggregationDescriptor); } @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "@type") diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNNode.java index b6d10b6cafad0..3ede56d5e8017 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNNode.java @@ -14,8 +14,8 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.OrderingScheme; -import com.facebook.presto.sql.planner.Symbol; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -79,9 +79,9 @@ public PlanNode getSource() } @Override - public List getOutputSymbols() + public List getOutputVariables() { - return source.getOutputSymbols(); + return source.getOutputVariables(); } @JsonProperty("count") diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java index 8b07ebe2eaf12..472cb20ccbf63 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java @@ -14,8 +14,8 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.OrderingScheme; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.WindowNode.Specification; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -28,7 +28,6 @@ import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.Iterables.concat; import static java.util.Objects.requireNonNull; @Immutable @@ -37,36 +36,36 @@ public final class TopNRowNumberNode { private final PlanNode source; private final Specification specification; - private final Symbol rowNumberSymbol; + private final VariableReferenceExpression rowNumberVariable; private final int maxRowCountPerPartition; private final boolean partial; - private final Optional hashSymbol; + private final Optional hashVariable; @JsonCreator public TopNRowNumberNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, @JsonProperty("specification") Specification specification, - @JsonProperty("rowNumberSymbol") Symbol rowNumberSymbol, + @JsonProperty("rowNumberVariable") VariableReferenceExpression rowNumberVariable, @JsonProperty("maxRowCountPerPartition") int maxRowCountPerPartition, @JsonProperty("partial") boolean partial, - @JsonProperty("hashSymbol") Optional hashSymbol) + @JsonProperty("hashVariable") Optional hashVariable) { super(id); requireNonNull(source, "source is null"); requireNonNull(specification, "specification is null"); checkArgument(specification.getOrderingScheme().isPresent(), "specification orderingScheme is absent"); - requireNonNull(rowNumberSymbol, "rowNumberSymbol is null"); + requireNonNull(rowNumberVariable, "rowNumberVariable is null"); checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0"); - requireNonNull(hashSymbol, "hashSymbol is null"); + requireNonNull(hashVariable, "hashVariable is null"); this.source = source; this.specification = specification; - this.rowNumberSymbol = rowNumberSymbol; + this.rowNumberVariable = rowNumberVariable; this.maxRowCountPerPartition = maxRowCountPerPartition; this.partial = partial; - this.hashSymbol = hashSymbol; + this.hashVariable = hashVariable; } @Override @@ -76,12 +75,14 @@ public List getSources() } @Override - public List getOutputSymbols() + public List getOutputVariables() { + ImmutableList.Builder builder = ImmutableList.builder().addAll(source.getOutputVariables()); + if (!partial) { - return ImmutableList.copyOf(concat(source.getOutputSymbols(), ImmutableList.of(rowNumberSymbol))); + builder.add(rowNumberVariable); } - return ImmutableList.copyOf(source.getOutputSymbols()); + return builder.build(); } @JsonProperty @@ -96,7 +97,7 @@ public Specification getSpecification() return specification; } - public List getPartitionBy() + public List getPartitionBy() { return specification.getPartitionBy(); } @@ -107,9 +108,9 @@ public OrderingScheme getOrderingScheme() } @JsonProperty - public Symbol getRowNumberSymbol() + public VariableReferenceExpression getRowNumberVariable() { - return rowNumberSymbol; + return rowNumberVariable; } @JsonProperty @@ -125,9 +126,9 @@ public boolean isPartial() } @JsonProperty - public Optional getHashSymbol() + public Optional getHashVariable() { - return hashSymbol; + return hashVariable; } @Override @@ -139,6 +140,6 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new TopNRowNumberNode(getId(), Iterables.getOnlyElement(newChildren), specification, rowNumberSymbol, maxRowCountPerPartition, partial, hashSymbol); + return new TopNRowNumberNode(getId(), Iterables.getOnlyElement(newChildren), specification, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnionNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnionNode.java index 6c06ceaa52e23..385f70ad907e3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnionNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnionNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ListMultimap; @@ -31,10 +31,9 @@ public class UnionNode public UnionNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("sources") List sources, - @JsonProperty("outputToInputs") ListMultimap outputToInputs, - @JsonProperty("outputs") List outputs) + @JsonProperty("outputToInputs") ListMultimap outputToInputs) { - super(id, sources, outputToInputs, outputs); + super(id, sources, outputToInputs); } @Override @@ -46,6 +45,6 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new UnionNode(getId(), newChildren, getSymbolMapping(), getOutputSymbols()); + return new UnionNode(getId(), newChildren, getVariableMapping()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnnestNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnnestNode.java index 302373e6d72ff..187afcce71a89 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnnestNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnnestNode.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -35,39 +35,39 @@ public class UnnestNode extends InternalPlanNode { private final PlanNode source; - private final List replicateSymbols; - private final Map> unnestSymbols; - private final Optional ordinalitySymbol; + private final List replicateVariables; + private final Map> unnestVariables; + private final Optional ordinalityVariable; @JsonCreator public UnnestNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("replicateSymbols") List replicateSymbols, - @JsonProperty("unnestSymbols") Map> unnestSymbols, - @JsonProperty("ordinalitySymbol") Optional ordinalitySymbol) + @JsonProperty("replicateVariables") List replicateVariables, + @JsonProperty("unnestVariables") Map> unnestVariables, + @JsonProperty("ordinalityVariable") Optional ordinalityVariable) { super(id); this.source = requireNonNull(source, "source is null"); - this.replicateSymbols = ImmutableList.copyOf(requireNonNull(replicateSymbols, "replicateSymbols is null")); - checkArgument(source.getOutputSymbols().containsAll(replicateSymbols), "Source does not contain all replicateSymbols"); - requireNonNull(unnestSymbols, "unnestSymbols is null"); - checkArgument(!unnestSymbols.isEmpty(), "unnestSymbols is empty"); - ImmutableMap.Builder> builder = ImmutableMap.builder(); - for (Map.Entry> entry : unnestSymbols.entrySet()) { + this.replicateVariables = ImmutableList.copyOf(requireNonNull(replicateVariables, "replicateVariables is null")); + checkArgument(source.getOutputVariables().containsAll(replicateVariables), "Source does not contain all replicateSymbols"); + requireNonNull(unnestVariables, "unnestVariables is null"); + checkArgument(!unnestVariables.isEmpty(), "unnestVariables is empty"); + ImmutableMap.Builder> builder = ImmutableMap.builder(); + for (Map.Entry> entry : unnestVariables.entrySet()) { builder.put(entry.getKey(), ImmutableList.copyOf(entry.getValue())); } - this.unnestSymbols = builder.build(); - this.ordinalitySymbol = requireNonNull(ordinalitySymbol, "ordinalitySymbol is null"); + this.unnestVariables = builder.build(); + this.ordinalityVariable = requireNonNull(ordinalityVariable, "ordinalityVariable is null"); } @Override - public List getOutputSymbols() + public List getOutputVariables() { - ImmutableList.Builder outputSymbolsBuilder = ImmutableList.builder() - .addAll(replicateSymbols) - .addAll(Iterables.concat(unnestSymbols.values())); - ordinalitySymbol.ifPresent(outputSymbolsBuilder::add); + ImmutableList.Builder outputSymbolsBuilder = ImmutableList.builder() + .addAll(replicateVariables) + .addAll(Iterables.concat(unnestVariables.values())); + ordinalityVariable.ifPresent(outputSymbolsBuilder::add); return outputSymbolsBuilder.build(); } @@ -78,21 +78,21 @@ public PlanNode getSource() } @JsonProperty - public List getReplicateSymbols() + public List getReplicateVariables() { - return replicateSymbols; + return replicateVariables; } @JsonProperty - public Map> getUnnestSymbols() + public Map> getUnnestVariables() { - return unnestSymbols; + return unnestVariables; } @JsonProperty - public Optional getOrdinalitySymbol() + public Optional getOrdinalityVariable() { - return ordinalitySymbol; + return ordinalityVariable; } @Override @@ -110,6 +110,6 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new UnnestNode(getId(), Iterables.getOnlyElement(newChildren), replicateSymbols, unnestSymbols, ordinalitySymbol); + return new UnnestNode(getId(), Iterables.getOnlyElement(newChildren), replicateVariables, unnestVariables, ordinalityVariable); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ValuesNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ValuesNode.java index 4ae16f5f49fdf..f0ea49663bfce 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ValuesNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ValuesNode.java @@ -15,7 +15,7 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.relation.RowExpression; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -31,29 +31,29 @@ public class ValuesNode extends InternalPlanNode { - private final List outputSymbols; + private final List outputVariables; private final List> rows; @JsonCreator public ValuesNode(@JsonProperty("id") PlanNodeId id, - @JsonProperty("outputSymbols") List outputSymbols, + @JsonProperty("outputVariables") List outputVariables, @JsonProperty("rows") List> rows) { super(id); - this.outputSymbols = ImmutableList.copyOf(outputSymbols); + this.outputVariables = ImmutableList.copyOf(outputVariables); this.rows = listOfListsCopy(rows); for (List row : rows) { - checkArgument(row.size() == outputSymbols.size() || row.size() == 0, - "Expected row to have %s values, but row has %s values", outputSymbols.size(), row.size()); + checkArgument(row.size() == outputVariables.size() || row.size() == 0, + "Expected row to have %s values, but row has %s values", outputVariables.size(), row.size()); } } @Override @JsonProperty - public List getOutputSymbols() + public List getOutputVariables() { - return outputSymbols; + return outputVariables; } @JsonProperty diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/WindowNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/WindowNode.java index 4df2ff38f29f5..33506ee2a27ef 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/WindowNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/WindowNode.java @@ -16,8 +16,8 @@ import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.OrderingScheme; -import com.facebook.presto.sql.planner.Symbol; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -35,7 +35,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.Iterables.concat; import static java.util.Objects.requireNonNull; @Immutable @@ -43,20 +42,20 @@ public class WindowNode extends InternalPlanNode { private final PlanNode source; - private final Set prePartitionedInputs; + private final Set prePartitionedInputs; private final Specification specification; private final int preSortedOrderPrefix; - private final Map windowFunctions; - private final Optional hashSymbol; + private final Map windowFunctions; + private final Optional hashVariable; @JsonCreator public WindowNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, @JsonProperty("specification") Specification specification, - @JsonProperty("windowFunctions") Map windowFunctions, - @JsonProperty("hashSymbol") Optional hashSymbol, - @JsonProperty("prePartitionedInputs") Set prePartitionedInputs, + @JsonProperty("windowFunctions") Map windowFunctions, + @JsonProperty("hashVariable") Optional hashVariable, + @JsonProperty("prePartitionedInputs") Set prePartitionedInputs, @JsonProperty("preSortedOrderPrefix") int preSortedOrderPrefix) { super(id); @@ -64,7 +63,7 @@ public WindowNode( requireNonNull(source, "source is null"); requireNonNull(specification, "specification is null"); requireNonNull(windowFunctions, "windowFunctions is null"); - requireNonNull(hashSymbol, "hashSymbol is null"); + requireNonNull(hashVariable, "hashVariable is null"); checkArgument(specification.getPartitionBy().containsAll(prePartitionedInputs), "prePartitionedInputs must be contained in partitionBy"); Optional orderingScheme = specification.getOrderingScheme(); checkArgument(preSortedOrderPrefix == 0 || (orderingScheme.isPresent() && preSortedOrderPrefix <= orderingScheme.get().getOrderBy().size()), "Cannot have sorted more symbols than those requested"); @@ -74,7 +73,7 @@ public WindowNode( this.prePartitionedInputs = ImmutableSet.copyOf(prePartitionedInputs); this.specification = specification; this.windowFunctions = ImmutableMap.copyOf(windowFunctions); - this.hashSymbol = hashSymbol; + this.hashVariable = hashVariable; this.preSortedOrderPrefix = preSortedOrderPrefix; } @@ -85,14 +84,17 @@ public List getSources() } @Override - public List getOutputSymbols() + public List getOutputVariables() { - return ImmutableList.copyOf(concat(source.getOutputSymbols(), windowFunctions.keySet())); + return ImmutableList.builder() + .addAll(source.getOutputVariables()) + .addAll(windowFunctions.keySet()) + .build(); } - public Set getCreatedSymbols() + public Set getCreatedVariable() { - return ImmutableSet.copyOf(windowFunctions.keySet()); + return windowFunctions.keySet(); } @JsonProperty @@ -107,7 +109,7 @@ public Specification getSpecification() return specification; } - public List getPartitionBy() + public List getPartitionBy() { return specification.getPartitionBy(); } @@ -118,7 +120,7 @@ public Optional getOrderingScheme() } @JsonProperty - public Map getWindowFunctions() + public Map getWindowFunctions() { return windowFunctions; } @@ -131,13 +133,13 @@ public List getFrames() } @JsonProperty - public Optional getHashSymbol() + public Optional getHashVariable() { - return hashSymbol; + return hashVariable; } @JsonProperty - public Set getPrePartitionedInputs() + public Set getPrePartitionedInputs() { return prePartitionedInputs; } @@ -157,18 +159,18 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new WindowNode(getId(), Iterables.getOnlyElement(newChildren), specification, windowFunctions, hashSymbol, prePartitionedInputs, preSortedOrderPrefix); + return new WindowNode(getId(), Iterables.getOnlyElement(newChildren), specification, windowFunctions, hashVariable, prePartitionedInputs, preSortedOrderPrefix); } @Immutable public static class Specification { - private final List partitionBy; + private final List partitionBy; private final Optional orderingScheme; @JsonCreator public Specification( - @JsonProperty("partitionBy") List partitionBy, + @JsonProperty("partitionBy") List partitionBy, @JsonProperty("orderingScheme") Optional orderingScheme) { requireNonNull(partitionBy, "partitionBy is null"); @@ -179,7 +181,7 @@ public Specification( } @JsonProperty - public List getPartitionBy() + public List getPartitionBy() { return partitionBy; } @@ -219,9 +221,9 @@ public static class Frame { private final WindowType type; private final BoundType startType; - private final Optional startValue; + private final Optional startValue; private final BoundType endType; - private final Optional endValue; + private final Optional endValue; // This information is only used for printing the plan. private final Optional originalStartValue; @@ -231,9 +233,9 @@ public static class Frame public Frame( @JsonProperty("type") WindowType type, @JsonProperty("startType") BoundType startType, - @JsonProperty("startValue") Optional startValue, + @JsonProperty("startValue") Optional startValue, @JsonProperty("endType") BoundType endType, - @JsonProperty("endValue") Optional endValue, + @JsonProperty("endValue") Optional endValue, @JsonProperty("originalStartValue") Optional originalStartValue, @JsonProperty("originalEndValue") Optional originalEndValue) { @@ -267,7 +269,7 @@ public BoundType getStartType() } @JsonProperty - public Optional getStartValue() + public Optional getStartValue() { return startValue; } @@ -279,7 +281,7 @@ public BoundType getEndType() } @JsonProperty - public Optional getEndValue() + public Optional getEndValue() { return endValue; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/NodeRepresentation.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/NodeRepresentation.java index 5f1aa23684a44..3ee136acf0b35 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/NodeRepresentation.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/NodeRepresentation.java @@ -16,6 +16,7 @@ import com.facebook.presto.cost.PlanCostEstimate; import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanFragmentId; @@ -32,7 +33,7 @@ public class NodeRepresentation private final String name; private final String type; private final String identifier; - private final List outputs; + private final List outputs; private final List children; private final List remoteSources; private final Optional stats; @@ -46,7 +47,7 @@ public NodeRepresentation( String name, String type, String identifier, - List outputs, + List outputs, Optional stats, List estimatedStats, List estimatedCost, @@ -103,7 +104,7 @@ public String getIdentifier() return identifier; } - public List getOutputs() + public List getOutputs() { return outputs; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 1077f0c354a90..eb9787917456b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -30,11 +30,12 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.predicate.Domain; import com.facebook.presto.spi.predicate.Marker; -import com.facebook.presto.spi.predicate.NullableValue; import com.facebook.presto.spi.predicate.Range; import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.InterpretedFunctionInvoker; import com.facebook.presto.sql.planner.OrderingScheme; @@ -88,7 +89,6 @@ import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.planner.plan.WindowNode; -import com.facebook.presto.sql.planner.planPrinter.NodeRepresentation.OutputSymbol; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SymbolReference; @@ -180,11 +180,10 @@ public String toJson() return new JsonRenderer().render(representation); } - public static String jsonFragmentPlan(PlanNode root, Map symbols, FunctionManager functionManager, Session session) + public static String jsonFragmentPlan(PlanNode root, Set variables, FunctionManager functionManager, Session session) { - TypeProvider typeProvider = TypeProvider.copyOf(symbols.entrySet().stream() - .distinct() - .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue))); + TypeProvider typeProvider = TypeProvider.copyOf(variables.stream() + .collect(toImmutableMap(variable -> new Symbol(variable.getName()), VariableReferenceExpression::getType))); return new PlanPrinter(root, typeProvider, Optional.empty(), functionManager, StatsAndCosts.empty(), session, Optional.empty()).toJson(); } @@ -281,7 +280,7 @@ private static String formatFragment(FunctionManager functionManager, Session se List arguments = partitioningScheme.getPartitioning().getArguments().stream() .map(argument -> { if (argument.isConstant()) { - NullableValue constant = argument.getConstant(); + ConstantExpression constant = argument.getConstant(); String printableValue = castToVarchar(constant.getType(), constant.getValue(), functionManager, session); return constant.getType().getDisplayName() + "(" + printableValue + ")"; } @@ -304,9 +303,9 @@ private static String formatFragment(FunctionManager functionManager, Session se builder.append(indentString(1)).append(format("Stage Execution Strategy: %s\n", fragment.getStageExecutionDescriptor().getStageExecutionStrategy())); TypeProvider typeProvider = TypeProvider.copyOf(allFragments.stream() - .flatMap(f -> f.getSymbols().entrySet().stream()) + .flatMap(f -> f.getVariables().stream()) .distinct() - .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue))); + .collect(toImmutableMap(variable -> new Symbol(variable.getName()), VariableReferenceExpression::getType))); builder.append(textLogicalPlan(fragment.getRoot(), typeProvider, Optional.of(fragment.getStageExecutionDescriptor()), functionManager, fragment.getStatsAndCosts(), session, planNodeStats, 1, verbose)) .append("\n"); @@ -319,10 +318,10 @@ public static String graphvizLogicalPlan(PlanNode plan, TypeProvider types, Sess PlanFragment fragment = new PlanFragment( new PlanFragmentId(0), plan, - types.allTypes(), + types.allVariables(), SINGLE_DISTRIBUTION, ImmutableList.of(plan.getId()), - new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), plan.getOutputSymbols()), + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), plan.getOutputVariables()), StageExecutionDescriptor.ungroupedExecution(), false, StatsAndCosts.empty(), @@ -377,7 +376,7 @@ public Void visitJoin(JoinNode node, Void context) else { nodeOutput = addNode(node, node.getType().getJoinLabel(), - format("[%s]%s", Joiner.on(" AND ").join(joinExpressions), formatHash(node.getLeftHashSymbol(), node.getRightHashSymbol()))); + format("[%s]%s", Joiner.on(" AND ").join(joinExpressions), formatHash(node.getLeftHashVariable(), node.getRightHashVariable()))); } node.getDistributionType().ifPresent(distributionType -> nodeOutput.appendDetails("Distribution: %s", distributionType)); @@ -409,9 +408,9 @@ public Void visitSemiJoin(SemiJoinNode node, Void context) NodeRepresentation nodeOutput = addNode(node, "SemiJoin", format("[%s = %s]%s", - node.getSourceJoinSymbol(), - node.getFilteringSourceJoinSymbol(), - formatHash(node.getSourceHashSymbol(), node.getFilteringSourceHashSymbol()))); + node.getSourceJoinVariable(), + node.getFilteringSourceJoinVariable(), + formatHash(node.getSourceHashVariable(), node.getFilteringSourceHashVariable()))); node.getDistributionType().ifPresent(distributionType -> nodeOutput.appendDetailsLine("Distribution: %s", distributionType)); node.getSource().accept(this, context); node.getFilteringSource().accept(this, context); @@ -424,10 +423,10 @@ public Void visitIndexSource(IndexSourceNode node, Void context) { NodeRepresentation nodeOutput = addNode(node, "IndexSource", - format("[%s, lookup = %s]", node.getIndexHandle(), node.getLookupSymbols())); + format("[%s, lookup = %s]", node.getIndexHandle(), node.getLookupVariables())); - for (Map.Entry entry : node.getAssignments().entrySet()) { - if (node.getOutputSymbols().contains(entry.getKey())) { + for (Map.Entry entry : node.getAssignments().entrySet()) { + if (node.getOutputVariables().contains(entry.getKey())) { nodeOutput.appendDetailsLine("%s := %s", entry.getKey(), entry.getValue()); } } @@ -440,13 +439,13 @@ public Void visitIndexJoin(IndexJoinNode node, Void context) List joinExpressions = new ArrayList<>(); for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) { joinExpressions.add(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, - clause.getProbe().toSymbolReference(), - clause.getIndex().toSymbolReference())); + new SymbolReference(clause.getProbe().getName()), + new SymbolReference(clause.getIndex().getName()))); } addNode(node, format("%sIndexJoin", node.getType().getJoinLabel()), - format("[%s]%s", Joiner.on(" AND ").join(joinExpressions), formatHash(node.getProbeHashSymbol(), node.getIndexHashSymbol()))); + format("[%s]%s", Joiner.on(" AND ").join(joinExpressions), formatHash(node.getProbeHashVariable(), node.getIndexHashVariable()))); node.getProbeSource().accept(this, context); node.getIndexSource().accept(this, context); @@ -467,7 +466,7 @@ public Void visitDistinctLimit(DistinctLimitNode node, Void context) { addNode(node, format("DistinctLimit%s", node.isPartial() ? "Partial" : ""), - format("[%s]%s", node.getLimit(), formatHash(node.getHashSymbol()))); + format("[%s]%s", node.getLimit(), formatHash(node.getHashVariable()))); return processChildren(node, context); } @@ -487,9 +486,9 @@ public Void visitAggregation(AggregationNode node, Void context) } NodeRepresentation nodeOutput = addNode(node, - format("Aggregate%s%s%s", type, key, formatHash(node.getHashSymbol()))); + format("Aggregate%s%s%s", type, key, formatHash(node.getHashVariable()))); - for (Map.Entry entry : node.getAggregations().entrySet()) { + for (Map.Entry entry : node.getAggregations().entrySet()) { nodeOutput.appendDetailsLine("%s := %s", entry.getKey(), formatAggregation(entry.getValue())); } @@ -523,7 +522,7 @@ private String formatAggregation(AggregationNode.Aggregation aggregation) public Void visitGroupId(GroupIdNode node, Void context) { // grouping sets are easier to understand in terms of inputs - List> inputGroupingSetSymbols = node.getGroupingSets().stream() + List> inputGroupingSetSymbols = node.getGroupingSets().stream() .map(set -> set.stream() .map(symbol -> node.getGroupingColumns().get(symbol)) .collect(Collectors.toList())) @@ -531,7 +530,7 @@ public Void visitGroupId(GroupIdNode node, Void context) NodeRepresentation nodeOutput = addNode(node, "GroupId", format("%s", inputGroupingSetSymbols)); - for (Map.Entry mapping : node.getGroupingColumns().entrySet()) { + for (Map.Entry mapping : node.getGroupingColumns().entrySet()) { nodeOutput.appendDetailsLine("%s := %s", mapping.getKey(), mapping.getValue()); } @@ -543,7 +542,7 @@ public Void visitMarkDistinct(MarkDistinctNode node, Void context) { addNode(node, "MarkDistinct", - format("[distinct=%s marker=%s]%s", formatOutputs(types, node.getDistinctSymbols()), node.getMarkerSymbol(), formatHash(node.getHashSymbol()))); + format("[distinct=%s marker=%s]%s", formatOutputs(node.getDistinctVariables()), node.getMarkerVariable(), formatHash(node.getHashVariable()))); return processChildren(node, context); } @@ -555,11 +554,11 @@ public Void visitWindow(WindowNode node, Void context) List args = new ArrayList<>(); if (!partitionBy.isEmpty()) { - List prePartitioned = node.getPartitionBy().stream() + List prePartitioned = node.getPartitionBy().stream() .filter(node.getPrePartitionedInputs()::contains) .collect(toImmutableList()); - List notPrePartitioned = node.getPartitionBy().stream() + List notPrePartitioned = node.getPartitionBy().stream() .filter(column -> !node.getPrePartitionedInputs().contains(column)) .collect(toImmutableList()); @@ -589,9 +588,9 @@ public Void visitWindow(WindowNode node, Void context) .collect(Collectors.joining(", ")))); } - NodeRepresentation nodeOutput = addNode(node, "Window", format("[%s]%s", Joiner.on(", ").join(args), formatHash(node.getHashSymbol()))); + NodeRepresentation nodeOutput = addNode(node, "Window", format("[%s]%s", Joiner.on(", ").join(args), formatHash(node.getHashVariable()))); - for (Map.Entry entry : node.getWindowFunctions().entrySet()) { + for (Map.Entry entry : node.getWindowFunctions().entrySet()) { CallExpression call = entry.getValue().getFunctionCall(); String frameInfo = formatFrame(entry.getValue().getFrame()); @@ -622,9 +621,9 @@ public Void visitTopNRowNumber(TopNRowNumberNode node, Void context) NodeRepresentation nodeOutput = addNode(node, "TopNRowNumber", - format("[%s limit %s]%s", Joiner.on(", ").join(args), node.getMaxRowCountPerPartition(), formatHash(node.getHashSymbol()))); + format("[%s limit %s]%s", Joiner.on(", ").join(args), node.getMaxRowCountPerPartition(), formatHash(node.getHashVariable()))); - nodeOutput.appendDetailsLine("%s := %s", node.getRowNumberSymbol(), "row_number()"); + nodeOutput.appendDetailsLine("%s := %s", node.getRowNumberVariable(), "row_number()"); return processChildren(node, context); } @@ -644,8 +643,8 @@ public Void visitRowNumber(RowNumberNode node, Void context) NodeRepresentation nodeOutput = addNode(node, "RowNumber", - format("[%s]%s", Joiner.on(", ").join(args), formatHash(node.getHashSymbol()))); - nodeOutput.appendDetailsLine("%s := %s", node.getRowNumberSymbol(), "row_number()"); + format("[%s]%s", Joiner.on(", ").join(args), formatHash(node.getHashVariable()))); + nodeOutput.appendDetailsLine("%s := %s", node.getRowNumberVariable(), "row_number()"); return processChildren(node, context); } @@ -799,7 +798,7 @@ private void printTableScanInfo(NodeRepresentation nodeOutput, TableScanNode nod } else { // first, print output columns and their constraints - for (Map.Entry assignment : node.getAssignments().entrySet()) { + for (Map.Entry assignment : node.getAssignments().entrySet()) { ColumnHandle column = assignment.getValue(); nodeOutput.appendDetailsLine("%s := %s", assignment.getKey(), column); printConstraint(nodeOutput, column, predicate); @@ -826,7 +825,7 @@ public Void visitUnnest(UnnestNode node, Void context) { addNode(node, "Unnest", - format("[replicate=%s, unnest=%s]", formatOutputs(types, node.getReplicateSymbols()), formatOutputs(types, node.getUnnestSymbols().keySet()))); + format("[replicate=%s, unnest=%s]", formatOutputs(node.getReplicateVariables()), formatOutputs(node.getUnnestVariables().keySet()))); return processChildren(node, context); } @@ -836,9 +835,9 @@ public Void visitOutput(OutputNode node, Void context) NodeRepresentation nodeOutput = addNode(node, "Output", format("[%s]", Joiner.on(", ").join(node.getColumnNames()))); for (int i = 0; i < node.getColumnNames().size(); i++) { String name = node.getColumnNames().get(i); - Symbol symbol = node.getOutputSymbols().get(i); - if (!name.equals(symbol.toString())) { - nodeOutput.appendDetailsLine("%s := %s", name, symbol); + VariableReferenceExpression variable = node.getOutputVariables().get(i); + if (!name.equals(variable.toString())) { + nodeOutput.appendDetailsLine("%s := %s", name, variable); } } return processChildren(node, context); @@ -914,8 +913,8 @@ public Void visitTableWriter(TableWriterNode node, Void context) NodeRepresentation nodeOutput = addNode(node, "TableWriter"); for (int i = 0; i < node.getColumnNames().size(); i++) { String name = node.getColumnNames().get(i); - Symbol symbol = node.getColumns().get(i); - nodeOutput.appendDetailsLine("%s := %s", name, symbol); + VariableReferenceExpression variable = node.getColumns().get(i); + nodeOutput.appendDetailsLine("%s := %s", name, variable); } int statisticsCollected = node.getStatisticsAggregation() @@ -1057,7 +1056,7 @@ private Void processChildren(PlanNode node, Void context) private void printAssignments(NodeRepresentation nodeOutput, Assignments assignments) { - for (Map.Entry entry : assignments.getMap().entrySet()) { + for (Map.Entry entry : assignments.getMap().entrySet()) { if (entry.getValue() instanceof SymbolReference && ((SymbolReference) entry.getValue()).getName().equals(entry.getKey().getName())) { // skip identity assignments continue; @@ -1160,9 +1159,7 @@ public NodeRepresentation addNode(PlanNode rootNode, String name, String identif name, rootNode.getClass().getSimpleName(), identifier, - rootNode.getOutputSymbols().stream() - .map(s -> new OutputSymbol(s, types.get(s).getDisplayName())) - .collect(toImmutableList()), + rootNode.getOutputVariables(), stats.map(s -> s.get(rootNode.getId())), estimatedStats, estimatedCosts, @@ -1203,24 +1200,24 @@ private static String formatFrame(WindowNode.Frame frame) return builder.toString(); } - private static String formatHash(Optional... hashes) + private static String formatHash(Optional... hashes) { - List symbols = stream(hashes) + List variables = stream(hashes) .filter(Optional::isPresent) .map(Optional::get) .collect(toList()); - if (symbols.isEmpty()) { + if (variables.isEmpty()) { return ""; } - return "[" + Joiner.on(", ").join(symbols) + "]"; + return "[" + Joiner.on(", ").join(variables) + "]"; } - private static String formatOutputs(TypeProvider types, Iterable outputs) + private static String formatOutputs(Iterable outputs) { return Streams.stream(outputs) - .map(input -> input + ":" + types.get(input).getDisplayName()) + .map(input -> input + ":" + input.getType().getDisplayName()) .collect(Collectors.joining(", ")); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/TextRenderer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/TextRenderer.java index 1f0622899105a..cc030005fb5c3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/TextRenderer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/TextRenderer.java @@ -15,7 +15,6 @@ import com.facebook.presto.cost.PlanCostEstimate; import com.facebook.presto.cost.PlanNodeStatsEstimate; -import com.facebook.presto.sql.planner.Symbol; import com.google.common.base.Strings; import com.google.common.collect.ImmutableMap; @@ -61,7 +60,7 @@ private String writeTextOutput(StringBuilder output, PlanRepresentation plan, in .append(node.getIdentifier()) .append(" => [") .append(node.getOutputs().stream() - .map(s -> s.getSymbol() + ":" + s.getType()) + .map(s -> s.getName() + ":" + s.getType().getDisplayName()) .collect(joining(", "))) .append("]\n"); @@ -219,13 +218,9 @@ private String printEstimates(PlanRepresentation plan, NodeRepresentation node) PlanNodeStatsEstimate stats = node.getEstimatedStats().get(i); PlanCostEstimate cost = node.getEstimatedCost().get(i); - List outputSymbols = node.getOutputs().stream() - .map(NodeRepresentation.OutputSymbol::getSymbol) - .collect(toList()); - output.append(format("{rows: %s (%s), cpu: %s, memory: %s, network: %s}", formatAsLong(stats.getOutputRowCount()), - formatEstimateAsDataSize(stats.getOutputSizeInBytes(outputSymbols, plan.getTypes())), + formatEstimateAsDataSize(stats.getOutputSizeInBytes(node.getOutputs())), formatDouble(cost.getCpuCost()), formatDouble(cost.getMaxMemory()), formatDouble(cost.getNetworkCost()))); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java index f8798eb78bc97..46feb3719b999 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java @@ -18,6 +18,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; @@ -112,16 +113,15 @@ public Void visitProject(ProjectNode node, Void context) { visitPlan(node, context); - for (Map.Entry entry : node.getAssignments().entrySet()) { - Type expectedType = types.get(entry.getKey()); + for (Map.Entry entry : node.getAssignments().entrySet()) { if (entry.getValue() instanceof SymbolReference) { SymbolReference symbolReference = (SymbolReference) entry.getValue(); - verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), types.get(Symbol.from(symbolReference)).getTypeSignature()); + verifyTypeSignature(entry.getKey(), types.get(Symbol.from(symbolReference)).getTypeSignature()); continue; } Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, entry.getValue(), emptyList(), warningCollector); Type actualType = expressionTypes.get(NodeRef.of(entry.getValue())); - verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), actualType.getTypeSignature()); + verifyTypeSignature(entry.getKey(), actualType.getTypeSignature()); } return null; @@ -132,67 +132,57 @@ public Void visitUnion(UnionNode node, Void context) { visitPlan(node, context); - ListMultimap symbolMapping = node.getSymbolMapping(); - for (Symbol keySymbol : symbolMapping.keySet()) { - List valueSymbols = symbolMapping.get(keySymbol); - Type expectedType = types.get(keySymbol); - for (Symbol valueSymbol : valueSymbols) { - verifyTypeSignature(keySymbol, expectedType.getTypeSignature(), types.get(valueSymbol).getTypeSignature()); + ListMultimap variableMapping = node.getVariableMapping(); + for (VariableReferenceExpression keyVariable : variableMapping.keySet()) { + List valueVariables = variableMapping.get(keyVariable); + for (VariableReferenceExpression valueVariable : valueVariables) { + verifyTypeSignature(keyVariable, valueVariable.getType().getTypeSignature()); } } return null; } - private void checkWindowFunctions(Map functions) + private void checkWindowFunctions(Map functions) { - for (Map.Entry entry : functions.entrySet()) { + for (Map.Entry entry : functions.entrySet()) { FunctionHandle functionHandle = entry.getValue().getFunctionHandle(); CallExpression call = entry.getValue().getFunctionCall(); - checkTypeSignature(entry.getKey(), metadata.getFunctionManager().getFunctionMetadata(functionHandle).getReturnType()); + verifyTypeSignature(entry.getKey(), metadata.getFunctionManager().getFunctionMetadata(functionHandle).getReturnType()); checkCall(entry.getKey(), call); } } - private void checkTypeSignature(Symbol symbol, TypeSignature actualTypeSignature) + private void checkCall(VariableReferenceExpression variable, CallExpression call) { - TypeSignature expectedTypeSignature = types.get(symbol).getTypeSignature(); - verifyTypeSignature(symbol, expectedTypeSignature, actualTypeSignature); - } - - private void checkCall(Symbol symbol, CallExpression call) - { - Type expectedType = types.get(symbol); Type actualType = call.getType(); - verifyTypeSignature(symbol, expectedType.getTypeSignature(), actualType.getTypeSignature()); + verifyTypeSignature(variable, actualType.getTypeSignature()); } - private void checkFunctionSignature(Map aggregations) + private void checkFunctionSignature(Map aggregations) { - for (Map.Entry entry : aggregations.entrySet()) { - checkTypeSignature(entry.getKey(), metadata.getFunctionManager().getFunctionMetadata(entry.getValue().getFunctionHandle()).getReturnType()); + for (Map.Entry entry : aggregations.entrySet()) { + verifyTypeSignature(entry.getKey(), metadata.getFunctionManager().getFunctionMetadata(entry.getValue().getFunctionHandle()).getReturnType()); } } - private void checkAggregation(Map aggregations) + private void checkAggregation(Map aggregations) { - for (Map.Entry entry : aggregations.entrySet()) { - Symbol symbol = entry.getKey(); + for (Map.Entry entry : aggregations.entrySet()) { verifyTypeSignature( - symbol, - types.get(symbol).getTypeSignature(), + entry.getKey(), metadata.getFunctionManager().getFunctionMetadata(entry.getValue().getFunctionHandle()).getReturnType()); // TODO check if the argument type agrees with function handle (will be added once Aggregation is using CallExpression). } } - private void verifyTypeSignature(Symbol symbol, TypeSignature expected, TypeSignature actual) + private void verifyTypeSignature(VariableReferenceExpression variable, TypeSignature actual) { // UNKNOWN should be considered as a wildcard type, which matches all the other types TypeManager typeManager = metadata.getTypeManager(); - if (!actual.equals(UNKNOWN.getTypeSignature()) && !typeManager.isTypeOnlyCoercion(typeManager.getType(actual), typeManager.getType(expected))) { - checkArgument(expected.equals(actual), "type of symbol '%s' is expected to be %s, but the actual type is %s", symbol, expected, actual); + if (!actual.equals(UNKNOWN.getTypeSignature()) && !typeManager.isTypeOnlyCoercion(typeManager.getType(actual), variable.getType())) { + checkArgument(variable.getType().getTypeSignature().equals(actual), "type of variable '%s' is expected to be %s, but the actual type is %s", variable.getName(), variable.getType(), actual); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index 24ab9e2b2bf59..ef31ea131211a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -16,8 +16,8 @@ import com.facebook.presto.Session; import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.optimizations.WindowNodeUtil; @@ -74,13 +74,14 @@ import java.util.Optional; import java.util.Set; -import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.extractUnique; +import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.extractUniqueVariables; import static com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer.IndexKeyTracer; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Objects.requireNonNull; /** * Ensures that all dependencies (i.e., symbols in expressions) for a plan node are provided by its source nodes @@ -91,46 +92,53 @@ public final class ValidateDependenciesChecker @Override public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) { - validate(plan); + validate(plan, types); } - public static void validate(PlanNode plan) + public static void validate(PlanNode plan, TypeProvider types) { - plan.accept(new Visitor(), ImmutableSet.of()); + plan.accept(new Visitor(types), ImmutableSet.of()); } private static class Visitor - extends InternalPlanVisitor> + extends InternalPlanVisitor> { + private final TypeProvider types; + + public Visitor(TypeProvider types) + { + this.types = requireNonNull(types, "types is null"); + } + @Override - protected Void visitPlan(PlanNode node, Set boundSymbols) + protected Void visitPlan(PlanNode node, Set boundVariables) { throw new UnsupportedOperationException("not yet implemented: " + node.getClass().getName()); } @Override - public Void visitExplainAnalyze(ExplainAnalyzeNode node, Set boundSymbols) + public Void visitExplainAnalyze(ExplainAnalyzeNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child return null; } @Override - public Void visitAggregation(AggregationNode node, Set boundSymbols) + public Void visitAggregation(AggregationNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child - Set inputs = createInputs(source, boundSymbols); - checkDependencies(inputs, node.getGroupingKeys(), "Invalid node. Grouping key symbols (%s) not in source plan output (%s)", node.getGroupingKeys(), node.getSource().getOutputSymbols()); + Set inputs = createInputs(source, boundVariables); + checkDependencies(inputs, node.getGroupingKeys(), "Invalid node. Grouping key variables (%s) not in source plan output (%s)", node.getGroupingKeys(), node.getSource().getOutputVariables()); for (Aggregation aggregation : node.getAggregations().values()) { - Set dependencies = extractUnique(aggregation); - checkDependencies(inputs, dependencies, "Invalid node. Aggregation dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputSymbols()); + Set dependencies = extractUniqueVariables(aggregation, types); + checkDependencies(inputs, dependencies, "Invalid node. Aggregation dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputVariables()); aggregation.getMask().ifPresent(mask -> { - checkDependencies(inputs, ImmutableSet.of(mask), "Invalid node. Aggregation mask symbol (%s) not in source plan output (%s)", mask, node.getSource().getOutputSymbols()); + checkDependencies(inputs, ImmutableSet.of(mask), "Invalid node. Aggregation mask symbol (%s) not in source plan output (%s)", mask, node.getSource().getOutputVariables()); }); } @@ -138,45 +146,45 @@ public Void visitAggregation(AggregationNode node, Set boundSymbols) } @Override - public Void visitGroupId(GroupIdNode node, Set boundSymbols) + public Void visitGroupId(GroupIdNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child - checkDependencies(source.getOutputSymbols(), node.getInputSymbols(), "Invalid node. Grouping symbols (%s) not in source plan output (%s)", node.getInputSymbols(), source.getOutputSymbols()); + checkDependencies(source.getOutputVariables(), node.getInputVariables(), "Invalid node. Grouping symbols (%s) not in source plan output (%s)", node.getInputVariables(), source.getOutputVariables()); return null; } @Override - public Void visitMarkDistinct(MarkDistinctNode node, Set boundSymbols) + public Void visitMarkDistinct(MarkDistinctNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child - checkDependencies(source.getOutputSymbols(), node.getDistinctSymbols(), "Invalid node. Mark distinct symbols (%s) not in source plan output (%s)", node.getDistinctSymbols(), source.getOutputSymbols()); + checkDependencies(source.getOutputVariables(), node.getDistinctVariables(), "Invalid node. Mark distinct symbols (%s) not in source plan output (%s)", node.getDistinctVariables(), source.getOutputVariables()); return null; } @Override - public Void visitWindow(WindowNode node, Set boundSymbols) + public Void visitWindow(WindowNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child - Set inputs = createInputs(source, boundSymbols); + Set inputs = createInputs(source, boundVariables); - checkDependencies(inputs, node.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)", node.getPartitionBy(), node.getSource().getOutputSymbols()); + checkDependencies(inputs, node.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)", node.getPartitionBy(), node.getSource().getOutputVariables()); if (node.getOrderingScheme().isPresent()) { checkDependencies( inputs, node.getOrderingScheme().get().getOrderBy(), "Invalid node. Order by symbols (%s) not in source plan output (%s)", - node.getOrderingScheme().get().getOrderBy(), node.getSource().getOutputSymbols()); + node.getOrderingScheme().get().getOrderBy(), node.getSource().getOutputVariables()); } - ImmutableList.Builder bounds = ImmutableList.builder(); + ImmutableList.Builder bounds = ImmutableList.builder(); for (WindowNode.Frame frame : node.getFrames()) { if (frame.getStartValue().isPresent()) { bounds.add(frame.getStartValue().get()); @@ -185,83 +193,90 @@ public Void visitWindow(WindowNode node, Set boundSymbols) bounds.add(frame.getEndValue().get()); } } - checkDependencies(inputs, bounds.build(), "Invalid node. Frame bounds (%s) not in source plan output (%s)", bounds.build(), node.getSource().getOutputSymbols()); + checkDependencies(inputs, bounds.build(), "Invalid node. Frame bounds (%s) not in source plan output (%s)", bounds.build(), node.getSource().getOutputVariables()); for (WindowNode.Function function : node.getWindowFunctions().values()) { - Set dependencies = WindowNodeUtil.extractWindowFunctionUnique(function); - checkDependencies(inputs, dependencies, "Invalid node. Window function dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputSymbols()); + Set dependencies = WindowNodeUtil.extractWindowFunctionUniqueVariables(function, types); + checkDependencies(inputs, dependencies, "Invalid node. Window function dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputVariables()); } return null; } @Override - public Void visitTopNRowNumber(TopNRowNumberNode node, Set boundSymbols) + public Void visitTopNRowNumber(TopNRowNumberNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child - Set inputs = createInputs(source, boundSymbols); - checkDependencies(inputs, node.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)", node.getPartitionBy(), node.getSource().getOutputSymbols()); + Set inputs = createInputs(source, boundVariables); + checkDependencies(inputs, node.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)", node.getPartitionBy(), node.getSource().getOutputVariables()); checkDependencies( inputs, node.getOrderingScheme().getOrderBy(), "Invalid node. Order by symbols (%s) not in source plan output (%s)", - node.getOrderingScheme().getOrderBy(), node.getSource().getOutputSymbols()); + node.getOrderingScheme().getOrderBy(), node.getSource().getOutputVariables()); return null; } @Override - public Void visitRowNumber(RowNumberNode node, Set boundSymbols) + public Void visitRowNumber(RowNumberNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child - checkDependencies(source.getOutputSymbols(), node.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)", node.getPartitionBy(), node.getSource().getOutputSymbols()); + checkDependencies(source.getOutputVariables(), node.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)", node.getPartitionBy(), source.getOutputVariables()); return null; } @Override - public Void visitFilter(FilterNode node, Set boundSymbols) + public Void visitFilter(FilterNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child - Set inputs = createInputs(source, boundSymbols); - checkDependencies(inputs, node.getOutputSymbols(), "Invalid node. Output symbols (%s) not in source plan output (%s)", node.getOutputSymbols(), node.getSource().getOutputSymbols()); + Set inputs = createInputs(source, boundVariables); + checkDependencies(inputs, node.getOutputVariables(), "Invalid node. Output symbols (%s) not in source plan output (%s)", node.getOutputVariables(), node.getSource().getOutputVariables()); - Set dependencies; + // Only verify names here as filter expression would contain type cast, which will be translated to an non-existent variable in + // SqlToRowExpressionTranslator + // TODO https://github.com/prestodb/presto/issues/12892 + Set dependencies; if (isExpression(node.getPredicate())) { - dependencies = SymbolsExtractor.extractUnique(castToExpression(node.getPredicate())); + dependencies = SymbolsExtractor.extractUniqueVariable(castToExpression(node.getPredicate()), types).stream().map(VariableReferenceExpression::getName).collect(toImmutableSet()); } else { - dependencies = SymbolsExtractor.extractUnique(node.getPredicate()); + dependencies = SymbolsExtractor.extractUniqueVariable(node.getPredicate()).stream().map(VariableReferenceExpression::getName).collect(toImmutableSet()); } - checkDependencies(inputs, dependencies, "Invalid node. Predicate dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputSymbols()); + checkArgument( + inputs.stream().map(VariableReferenceExpression::getName).collect(toImmutableSet()).containsAll(dependencies), + "Symbol from filter (%s) not in sources (%s)", + dependencies, + inputs); return null; } @Override - public Void visitSample(SampleNode node, Set boundSymbols) + public Void visitSample(SampleNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child return null; } @Override - public Void visitProject(ProjectNode node, Set boundSymbols) + public Void visitProject(ProjectNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child - Set inputs = createInputs(source, boundSymbols); + Set inputs = createInputs(source, boundVariables); for (Expression expression : node.getAssignments().getExpressions()) { - Set dependencies = SymbolsExtractor.extractUnique(expression); + Set dependencies = SymbolsExtractor.extractUniqueVariable(expression, types); checkDependencies(inputs, dependencies, "Invalid node. Expression dependencies (%s) not in source plan output (%s)", dependencies, inputs); } @@ -269,146 +284,149 @@ public Void visitProject(ProjectNode node, Set boundSymbols) } @Override - public Void visitTopN(TopNNode node, Set boundSymbols) + public Void visitTopN(TopNNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child - Set inputs = createInputs(source, boundSymbols); - checkDependencies(inputs, node.getOutputSymbols(), "Invalid node. Output symbols (%s) not in source plan output (%s)", node.getOutputSymbols(), node.getSource().getOutputSymbols()); + Set inputs = createInputs(source, boundVariables); + checkDependencies(inputs, node.getOutputVariables(), "Invalid node. Output symbols (%s) not in source plan output (%s)", node.getOutputVariables(), node.getSource().getOutputVariables()); checkDependencies( inputs, node.getOrderingScheme().getOrderBy(), "Invalid node. Order by dependencies (%s) not in source plan output (%s)", node.getOrderingScheme().getOrderBy(), - node.getSource().getOutputSymbols()); + node.getSource().getOutputVariables()); return null; } @Override - public Void visitSort(SortNode node, Set boundSymbols) + public Void visitSort(SortNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child - Set inputs = createInputs(source, boundSymbols); - checkDependencies(inputs, node.getOutputSymbols(), "Invalid node. Output symbols (%s) not in source plan output (%s)", node.getOutputSymbols(), node.getSource().getOutputSymbols()); + Set inputs = createInputs(source, boundVariables); + checkDependencies(inputs, node.getOutputVariables(), "Invalid node. Output symbols (%s) not in source plan output (%s)", node.getOutputVariables(), node.getSource().getOutputVariables()); checkDependencies( inputs, node.getOrderingScheme().getOrderBy(), "Invalid node. Order by dependencies (%s) not in source plan output (%s)", - node.getOrderingScheme().getOrderBy(), node.getSource().getOutputSymbols()); + node.getOrderingScheme().getOrderBy(), node.getSource().getOutputVariables()); return null; } @Override - public Void visitOutput(OutputNode node, Set boundSymbols) + public Void visitOutput(OutputNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child - checkDependencies(source.getOutputSymbols(), node.getOutputSymbols(), "Invalid node. Output column dependencies (%s) not in source plan output (%s)", node.getOutputSymbols(), source.getOutputSymbols()); + checkDependencies(source.getOutputVariables(), node.getOutputVariables(), "Invalid node. Output column dependencies (%s) not in source plan output (%s)", node.getOutputVariables(), source.getOutputVariables()); return null; } @Override - public Void visitLimit(LimitNode node, Set boundSymbols) + public Void visitLimit(LimitNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child return null; } @Override - public Void visitDistinctLimit(DistinctLimitNode node, Set boundSymbols) + public Void visitDistinctLimit(DistinctLimitNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child - checkDependencies(source.getOutputSymbols(), node.getOutputSymbols(), "Invalid node. Output column dependencies (%s) not in source plan output (%s)", node.getOutputSymbols(), source.getOutputSymbols()); + checkDependencies(source.getOutputVariables(), node.getOutputVariables(), "Invalid node. Output column dependencies (%s) not in source plan output (%s)", node.getOutputVariables(), source.getOutputVariables()); return null; } @Override - public Void visitJoin(JoinNode node, Set boundSymbols) + public Void visitJoin(JoinNode node, Set boundVariables) { - node.getLeft().accept(this, boundSymbols); - node.getRight().accept(this, boundSymbols); + node.getLeft().accept(this, boundVariables); + node.getRight().accept(this, boundVariables); - Set leftInputs = createInputs(node.getLeft(), boundSymbols); - Set rightInputs = createInputs(node.getRight(), boundSymbols); - Set allInputs = ImmutableSet.builder() + Set leftInputs = createInputs(node.getLeft(), boundVariables); + Set rightInputs = createInputs(node.getRight(), boundVariables); + Set allInputs = ImmutableSet.builder() .addAll(leftInputs) .addAll(rightInputs) .build(); for (JoinNode.EquiJoinClause clause : node.getCriteria()) { - checkArgument(leftInputs.contains(clause.getLeft()), "Symbol from join clause (%s) not in left source (%s)", clause.getLeft(), node.getLeft().getOutputSymbols()); - checkArgument(rightInputs.contains(clause.getRight()), "Symbol from join clause (%s) not in right source (%s)", clause.getRight(), node.getRight().getOutputSymbols()); + checkArgument(leftInputs.contains(clause.getLeft()), "Symbol from join clause (%s) not in left source (%s)", clause.getLeft(), node.getLeft().getOutputVariables()); + checkArgument(rightInputs.contains(clause.getRight()), "Symbol from join clause (%s) not in right source (%s)", clause.getRight(), node.getRight().getOutputVariables()); } node.getFilter().ifPresent(predicate -> { - Set predicateSymbols; + // Only verify names here as filter expression would contain type cast, which will be translated to an non-existent variable in + // SqlToRowExpressionTranslator + // TODO https://github.com/prestodb/presto/issues/12892 + Set predicateVariables; if (isExpression(predicate)) { - predicateSymbols = SymbolsExtractor.extractUnique(castToExpression(predicate)); + predicateVariables = SymbolsExtractor.extractUniqueVariable(castToExpression(predicate), types).stream().map(VariableReferenceExpression::getName).collect(toImmutableSet()); } else { - predicateSymbols = SymbolsExtractor.extractUnique(predicate); + predicateVariables = SymbolsExtractor.extractUniqueVariable(predicate).stream().map(VariableReferenceExpression::getName).collect(toImmutableSet()); } checkArgument( - allInputs.containsAll(predicateSymbols), + allInputs.stream().map(VariableReferenceExpression::getName).collect(toImmutableSet()).containsAll(predicateVariables), "Symbol from filter (%s) not in sources (%s)", - predicateSymbols, + predicateVariables, allInputs); }); - checkLeftOutputSymbolsBeforeRight(node.getLeft().getOutputSymbols(), node.getOutputSymbols()); + checkLeftOutputVariablesBeforeRight(node.getLeft().getOutputVariables(), node.getOutputVariables()); return null; } @Override - public Void visitSemiJoin(SemiJoinNode node, Set boundSymbols) + public Void visitSemiJoin(SemiJoinNode node, Set boundVariables) { - node.getSource().accept(this, boundSymbols); - node.getFilteringSource().accept(this, boundSymbols); + node.getSource().accept(this, boundVariables); + node.getFilteringSource().accept(this, boundVariables); - checkArgument(node.getSource().getOutputSymbols().contains(node.getSourceJoinSymbol()), "Symbol from semi join clause (%s) not in source (%s)", node.getSourceJoinSymbol(), node.getSource().getOutputSymbols()); - checkArgument(node.getFilteringSource().getOutputSymbols().contains(node.getFilteringSourceJoinSymbol()), "Symbol from semi join clause (%s) not in filtering source (%s)", node.getSourceJoinSymbol(), node.getFilteringSource().getOutputSymbols()); + checkArgument(node.getSource().getOutputVariables().contains(node.getSourceJoinVariable()), "Symbol from semi join clause (%s) not in source (%s)", node.getSourceJoinVariable(), node.getSource().getOutputVariables()); + checkArgument(node.getFilteringSource().getOutputVariables().contains(node.getFilteringSourceJoinVariable()), "Symbol from semi join clause (%s) not in filtering source (%s)", node.getSourceJoinVariable(), node.getFilteringSource().getOutputVariables()); - Set outputs = createInputs(node, boundSymbols); - checkArgument(outputs.containsAll(node.getSource().getOutputSymbols()), "Semi join output symbols (%s) must contain all of the source symbols (%s)", node.getOutputSymbols(), node.getSource().getOutputSymbols()); + Set outputs = createInputs(node, boundVariables); + checkArgument(outputs.containsAll(node.getSource().getOutputVariables()), "Semi join output symbols (%s) must contain all of the source symbols (%s)", node.getOutputVariables(), node.getSource().getOutputVariables()); checkArgument(outputs.contains(node.getSemiJoinOutput()), "Semi join output symbols (%s) must contain join result (%s)", - node.getOutputSymbols(), + node.getOutputVariables(), node.getSemiJoinOutput()); return null; } @Override - public Void visitSpatialJoin(SpatialJoinNode node, Set boundSymbols) + public Void visitSpatialJoin(SpatialJoinNode node, Set boundVariables) { - node.getLeft().accept(this, boundSymbols); - node.getRight().accept(this, boundSymbols); + node.getLeft().accept(this, boundVariables); + node.getRight().accept(this, boundVariables); - Set leftInputs = createInputs(node.getLeft(), boundSymbols); - Set rightInputs = createInputs(node.getRight(), boundSymbols); - Set allInputs = ImmutableSet.builder() + Set leftInputs = createInputs(node.getLeft(), boundVariables); + Set rightInputs = createInputs(node.getRight(), boundVariables); + Set allInputs = ImmutableSet.builder() .addAll(leftInputs) .addAll(rightInputs) .build(); - Set predicateSymbols; + Set predicateSymbols; if (isExpression(node.getFilter())) { - predicateSymbols = SymbolsExtractor.extractUnique(castToExpression(node.getFilter())); + predicateSymbols = SymbolsExtractor.extractUniqueVariable(castToExpression(node.getFilter()), types); } else { - predicateSymbols = SymbolsExtractor.extractUnique(node.getFilter()); + predicateSymbols = SymbolsExtractor.extractUniqueVariable(node.getFilter()); } checkArgument( @@ -417,232 +435,244 @@ public Void visitSpatialJoin(SpatialJoinNode node, Set boundSymbols) predicateSymbols, allInputs); - checkLeftOutputSymbolsBeforeRight(node.getLeft().getOutputSymbols(), node.getOutputSymbols()); + checkLeftOutputVariablesBeforeRight(node.getLeft().getOutputVariables(), node.getOutputVariables()); return null; } - private void checkLeftOutputSymbolsBeforeRight(List leftSymbols, List outputSymbols) + private void checkLeftOutputVariablesBeforeRight(List leftVariables, List outputVariables) { int leftMaxPosition = -1; Optional rightMinPosition = Optional.empty(); - Set leftSymbolsSet = new HashSet<>(leftSymbols); - for (int i = 0; i < outputSymbols.size(); i++) { - Symbol symbol = outputSymbols.get(i); - if (leftSymbolsSet.contains(symbol)) { + Set leftVariablesSet = new HashSet<>(leftVariables); + for (int i = 0; i < outputVariables.size(); i++) { + VariableReferenceExpression variable = outputVariables.get(i); + if (leftVariablesSet.contains(variable)) { leftMaxPosition = i; } else if (!rightMinPosition.isPresent()) { rightMinPosition = Optional.of(i); } } - checkState(!rightMinPosition.isPresent() || rightMinPosition.get() > leftMaxPosition, "Not all left output symbols are before right output symbols"); + checkState(!rightMinPosition.isPresent() || rightMinPosition.get() > leftMaxPosition, "Not all left output variables are before right output variables"); } @Override - public Void visitIndexJoin(IndexJoinNode node, Set boundSymbols) + public Void visitIndexJoin(IndexJoinNode node, Set boundVariables) { - node.getProbeSource().accept(this, boundSymbols); - node.getIndexSource().accept(this, boundSymbols); + node.getProbeSource().accept(this, boundVariables); + node.getIndexSource().accept(this, boundVariables); - Set probeInputs = createInputs(node.getProbeSource(), boundSymbols); - Set indexSourceInputs = createInputs(node.getIndexSource(), boundSymbols); + Set probeInputs = createInputs(node.getProbeSource(), boundVariables); + Set indexSourceInputs = createInputs(node.getIndexSource(), boundVariables); for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) { - checkArgument(probeInputs.contains(clause.getProbe()), "Probe symbol from index join clause (%s) not in probe source (%s)", clause.getProbe(), node.getProbeSource().getOutputSymbols()); - checkArgument(indexSourceInputs.contains(clause.getIndex()), "Index symbol from index join clause (%s) not in index source (%s)", clause.getIndex(), node.getIndexSource().getOutputSymbols()); + checkArgument(probeInputs.contains(clause.getProbe()), "Probe variable from index join clause (%s) not in probe source (%s)", clause.getProbe(), node.getProbeSource().getOutputVariables()); + checkArgument(indexSourceInputs.contains(clause.getIndex()), "Index variable from index join clause (%s) not in index source (%s)", clause.getIndex(), node.getIndexSource().getOutputVariables()); } - Set lookupSymbols = node.getCriteria().stream() + Set lookupVariables = node.getCriteria().stream() .map(IndexJoinNode.EquiJoinClause::getIndex) .collect(toImmutableSet()); - Map trace = IndexKeyTracer.trace(node.getIndexSource(), lookupSymbols); - checkArgument(!trace.isEmpty() && lookupSymbols.containsAll(trace.keySet()), + Map trace = IndexKeyTracer.trace(node.getIndexSource(), lookupVariables); + checkArgument(!trace.isEmpty() && lookupVariables.containsAll(trace.keySet()), "Index lookup symbols are not traceable to index source: %s", - lookupSymbols); + lookupVariables); return null; } @Override - public Void visitIndexSource(IndexSourceNode node, Set boundSymbols) + public Void visitIndexSource(IndexSourceNode node, Set boundVariables) { - checkDependencies(node.getOutputSymbols(), node.getLookupSymbols(), "Lookup symbols must be part of output symbols"); - checkDependencies(node.getAssignments().keySet(), node.getOutputSymbols(), "Assignments must contain mappings for output symbols"); + checkDependencies( + node.getOutputVariables(), + node.getLookupVariables(), + "Lookup variables must be part of output symbols"); + checkDependencies( + node.getAssignments().keySet(), + node.getOutputVariables(), + "Assignments must contain mappings for output symbols"); return null; } @Override - public Void visitTableScan(TableScanNode node, Set boundSymbols) + public Void visitTableScan(TableScanNode node, Set boundVariables) { //We don't have to do a check here as TableScanNode has no dependencies. return null; } @Override - public Void visitValues(ValuesNode node, Set boundSymbols) + public Void visitValues(ValuesNode node, Set boundVariables) { - Set correlatedDependencies = SymbolsExtractor.extractUnique(node); + Set correlatedDependencies = SymbolsExtractor.extractUniqueVariable(node, types); checkDependencies( - boundSymbols, + boundVariables, correlatedDependencies, "Invalid node. Expression correlated dependencies (%s) not satisfied by (%s)", correlatedDependencies, - boundSymbols); + boundVariables); return null; } @Override - public Void visitUnnest(UnnestNode node, Set boundSymbols) + public Void visitUnnest(UnnestNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); + source.accept(this, boundVariables); - Set required = ImmutableSet.builder() - .addAll(node.getReplicateSymbols()) - .addAll(node.getUnnestSymbols().keySet()) + Set required = ImmutableSet.builder() + .addAll(node.getReplicateVariables()) + .addAll(node.getUnnestVariables().keySet()) .build(); - checkDependencies(source.getOutputSymbols(), required, "Invalid node. Dependencies (%s) not in source plan output (%s)", required, source.getOutputSymbols()); + checkDependencies(source.getOutputVariables(), required, "Invalid node. Dependencies (%s) not in source plan output (%s)", required, source.getOutputVariables()); return null; } @Override - public Void visitRemoteSource(RemoteSourceNode node, Set boundSymbols) + public Void visitRemoteSource(RemoteSourceNode node, Set boundVariables) { return null; } @Override - public Void visitExchange(ExchangeNode node, Set boundSymbols) + public Void visitExchange(ExchangeNode node, Set boundVariables) { for (int i = 0; i < node.getSources().size(); i++) { PlanNode subplan = node.getSources().get(i); - checkDependencies(subplan.getOutputSymbols(), node.getInputs().get(i), "EXCHANGE subplan must provide all of the necessary symbols"); - subplan.accept(this, boundSymbols); // visit child + checkDependencies( + subplan.getOutputVariables(), + node.getInputs().get(i), + "EXCHANGE subplan must provide all of the necessary symbols"); + subplan.accept(this, boundVariables); // visit child } - checkDependencies(node.getOutputSymbols(), node.getPartitioningScheme().getOutputLayout(), "EXCHANGE must provide all of the necessary symbols for partition function"); + checkDependencies( + node.getOutputVariables(), + node.getPartitioningScheme().getOutputLayout(), + "EXCHANGE must provide all of the necessary symbols for partition function"); return null; } @Override - public Void visitTableWriter(TableWriterNode node, Set boundSymbols) + public Void visitTableWriter(TableWriterNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child return null; } @Override - public Void visitDelete(DeleteNode node, Set boundSymbols) + public Void visitDelete(DeleteNode node, Set boundVariables) { PlanNode source = node.getSource(); - source.accept(this, boundSymbols); // visit child + source.accept(this, boundVariables); // visit child - checkArgument(source.getOutputSymbols().contains(node.getRowId()), "Invalid node. Row ID symbol (%s) is not in source plan output (%s)", node.getRowId(), node.getSource().getOutputSymbols()); + checkArgument(source.getOutputVariables().contains(node.getRowId()), "Invalid node. Row ID symbol (%s) is not in source plan output (%s)", node.getRowId(), node.getSource().getOutputVariables()); return null; } @Override - public Void visitMetadataDelete(MetadataDeleteNode node, Set boundSymbols) + public Void visitMetadataDelete(MetadataDeleteNode node, Set boundVariables) { return null; } @Override - public Void visitStatisticsWriterNode(StatisticsWriterNode node, Set boundSymbols) + public Void visitStatisticsWriterNode(StatisticsWriterNode node, Set boundVariables) { - node.getSource().accept(this, boundSymbols); // visit child + node.getSource().accept(this, boundVariables); // visit child - StatisticAggregationsDescriptor descriptor = node.getDescriptor(); - Set dependencies = ImmutableSet.builder() + StatisticAggregationsDescriptor descriptor = node.getDescriptor(); + Set dependencies = ImmutableSet.builder() .addAll(descriptor.getGrouping().values()) .addAll(descriptor.getColumnStatistics().values()) .addAll(descriptor.getTableStatistics().values()) .build(); - List outputSymbols = node.getSource().getOutputSymbols(); - checkDependencies(dependencies, dependencies, "Invalid node. Dependencies (%s) not in source plan output (%s)", dependencies, outputSymbols); + List outputVariables = node.getSource().getOutputVariables(); + checkDependencies(dependencies, dependencies, "Invalid node. Dependencies (%s) not in source plan output (%s)", dependencies, outputVariables); return null; } @Override - public Void visitTableFinish(TableFinishNode node, Set boundSymbols) + public Void visitTableFinish(TableFinishNode node, Set boundVariables) { - node.getSource().accept(this, boundSymbols); // visit child + node.getSource().accept(this, boundVariables); // visit child return null; } @Override - public Void visitUnion(UnionNode node, Set boundSymbols) + public Void visitUnion(UnionNode node, Set boundVariables) { - return visitSetOperation(node, boundSymbols); + return visitSetOperation(node, boundVariables); } - private Void visitSetOperation(SetOperationNode node, Set boundSymbols) + private Void visitSetOperation(SetOperationNode node, Set boundVariables) { for (int i = 0; i < node.getSources().size(); i++) { PlanNode subplan = node.getSources().get(i); - checkDependencies(subplan.getOutputSymbols(), node.sourceOutputLayout(i), "%s subplan must provide all of the necessary symbols", node.getClass().getSimpleName()); - subplan.accept(this, boundSymbols); // visit child + checkDependencies(subplan.getOutputVariables(), node.sourceOutputLayout(i), "%s subplan must provide all of the necessary symbols", node.getClass().getSimpleName()); + subplan.accept(this, boundVariables); // visit child } return null; } @Override - public Void visitIntersect(IntersectNode node, Set boundSymbols) + public Void visitIntersect(IntersectNode node, Set boundVariables) { - return visitSetOperation(node, boundSymbols); + return visitSetOperation(node, boundVariables); } @Override - public Void visitExcept(ExceptNode node, Set boundSymbols) + public Void visitExcept(ExceptNode node, Set boundVariables) { - return visitSetOperation(node, boundSymbols); + return visitSetOperation(node, boundVariables); } @Override - public Void visitEnforceSingleRow(EnforceSingleRowNode node, Set boundSymbols) + public Void visitEnforceSingleRow(EnforceSingleRowNode node, Set boundVariables) { - node.getSource().accept(this, boundSymbols); // visit child + node.getSource().accept(this, boundVariables); // visit child return null; } @Override - public Void visitAssignUniqueId(AssignUniqueId node, Set boundSymbols) + public Void visitAssignUniqueId(AssignUniqueId node, Set boundVariables) { - node.getSource().accept(this, boundSymbols); // visit child + node.getSource().accept(this, boundVariables); // visit child return null; } @Override - public Void visitApply(ApplyNode node, Set boundSymbols) + public Void visitApply(ApplyNode node, Set boundVariables) { - Set subqueryCorrelation = ImmutableSet.builder() - .addAll(boundSymbols) + Set subqueryCorrelation = ImmutableSet.builder() + .addAll(boundVariables) .addAll(node.getCorrelation()) .build(); - node.getInput().accept(this, boundSymbols); // visit child + node.getInput().accept(this, boundVariables); // visit child node.getSubquery().accept(this, subqueryCorrelation); // visit child - checkDependencies(node.getInput().getOutputSymbols(), node.getCorrelation(), "APPLY input must provide all the necessary correlation symbols for subquery"); - checkDependencies(SymbolsExtractor.extractUnique(node.getSubquery()), node.getCorrelation(), "not all APPLY correlation symbols are used in subquery"); + checkDependencies(node.getInput().getOutputVariables(), node.getCorrelation(), "APPLY input must provide all the necessary correlation variables for subquery"); + checkDependencies(SymbolsExtractor.extractUniqueVariable(node.getSubquery(), types), node.getCorrelation(), "not all APPLY correlation symbols are used in subquery"); - ImmutableSet inputs = ImmutableSet.builder() - .addAll(createInputs(node.getSubquery(), boundSymbols)) - .addAll(createInputs(node.getInput(), boundSymbols)) + ImmutableSet inputs = ImmutableSet.builder() + .addAll(createInputs(node.getSubquery(), boundVariables)) + .addAll(createInputs(node.getInput(), boundVariables)) .build(); for (Expression expression : node.getSubqueryAssignments().getExpressions()) { - Set dependencies = SymbolsExtractor.extractUnique(expression); + Set dependencies = SymbolsExtractor.extractUniqueVariable(expression, types); checkDependencies(inputs, dependencies, "Invalid node. Expression dependencies (%s) not in source plan output (%s)", dependencies, inputs); } @@ -650,38 +680,38 @@ public Void visitApply(ApplyNode node, Set boundSymbols) } @Override - public Void visitLateralJoin(LateralJoinNode node, Set boundSymbols) + public Void visitLateralJoin(LateralJoinNode node, Set boundVariables) { - Set subqueryCorrelation = ImmutableSet.builder() - .addAll(boundSymbols) + Set subqueryCorrelation = ImmutableSet.builder() + .addAll(boundVariables) .addAll(node.getCorrelation()) .build(); - node.getInput().accept(this, boundSymbols); // visit child + node.getInput().accept(this, boundVariables); // visit child node.getSubquery().accept(this, subqueryCorrelation); // visit child checkDependencies( - node.getInput().getOutputSymbols(), + node.getInput().getOutputVariables(), node.getCorrelation(), "LATERAL input must provide all the necessary correlation symbols for subquery"); checkDependencies( - SymbolsExtractor.extractUnique(node.getSubquery()), + SymbolsExtractor.extractUniqueVariable(node.getSubquery(), types), node.getCorrelation(), "not all LATERAL correlation symbols are used in subquery"); return null; } - private static ImmutableSet createInputs(PlanNode source, Set boundSymbols) + private static ImmutableSet createInputs(PlanNode source, Set boundVariables) { - return ImmutableSet.builder() - .addAll(source.getOutputSymbols()) - .addAll(boundSymbols) + return ImmutableSet.builder() + .addAll(source.getOutputVariables()) + .addAll(boundVariables) .build(); } } - private static void checkDependencies(Collection inputs, Collection required, String message, Object... parameters) + private static void checkDependencies(Collection inputs, Collection required, String message, Object... parameters) { checkArgument(ImmutableSet.copyOf(inputs).containsAll(required), message, parameters); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingAggregations.java index f3b1ec26d13d7..8690e0c48fa8e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingAggregations.java @@ -18,8 +18,8 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.GroupingProperty; import com.facebook.presto.spi.LocalProperty; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.optimizations.LocalProperties; import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; @@ -77,15 +77,15 @@ protected Void visitPlan(PlanNode node, Void context) @Override public Void visitAggregation(AggregationNode node, Void context) { - if (node.getPreGroupedSymbols().isEmpty()) { + if (node.getPreGroupedVariables().isEmpty()) { return null; } StreamProperties properties = derivePropertiesRecursively(node.getSource(), metadata, sesstion, types, sqlParser); - List> desiredProperties = ImmutableList.of(new GroupingProperty<>(node.getPreGroupedSymbols())); - Iterator>> matchIterator = LocalProperties.match(properties.getLocalProperties(), desiredProperties).iterator(); - Optional> unsatisfiedRequirement = Iterators.getOnlyElement(matchIterator); + List> desiredProperties = ImmutableList.of(new GroupingProperty<>(node.getPreGroupedVariables())); + Iterator>> matchIterator = LocalProperties.match(properties.getLocalProperties(), desiredProperties).iterator(); + Optional> unsatisfiedRequirement = Iterators.getOnlyElement(matchIterator); checkArgument(!unsatisfiedRequirement.isPresent(), "Streaming aggregation with input not grouped on the grouping keys"); return null; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/Expressions.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/Expressions.java index 79ccb93bb138a..c2f4f546ff35d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/Expressions.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/Expressions.java @@ -130,4 +130,9 @@ public Void visitSpecialForm(SpecialFormExpression specialForm, Void context) return builder.build(); } + + public static VariableReferenceExpression variable(String name, Type type) + { + return new VariableReferenceExpression(name, type); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java index 8a6c2fa040a3b..4d8854a2331e7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.relational; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SymbolReference; @@ -26,10 +26,10 @@ private ProjectNodeUtils() {} public static boolean isIdentity(ProjectNode projectNode) { - for (Map.Entry entry : projectNode.getAssignments().entrySet()) { + for (Map.Entry entry : projectNode.getAssignments().entrySet()) { Expression expression = entry.getValue(); - Symbol symbol = entry.getKey(); - if (!(expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(symbol.getName()))) { + VariableReferenceExpression variable = entry.getKey(); + if (!(expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(variable.getName()))) { return false; } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java index bf3860a4d3942..cb0ed8b0e4b03 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java @@ -33,7 +33,6 @@ import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.VarcharType; import com.facebook.presto.sql.analyzer.TypeSignatureProvider; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.relational.optimizer.ExpressionOptimizer; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.ArithmeticUnaryExpression; @@ -149,7 +148,7 @@ private SqlToRowExpressionTranslator() {} public static RowExpression translate( Expression expression, Map, Type> types, - Map layout, + Map layout, FunctionManager functionManager, TypeManager typeManager, Session session, @@ -177,7 +176,7 @@ private static class Visitor extends AstVisitor { private final Map, Type> types; - private final Map layout; + private final Map layout; private final TypeManager typeManager; private final FunctionManager functionManager; private final Session session; @@ -185,7 +184,7 @@ private static class Visitor private Visitor( Map, Type> types, - Map layout, + Map layout, TypeManager typeManager, FunctionManager functionManager, Session session) @@ -390,12 +389,13 @@ protected RowExpression visitFunctionCall(FunctionCall node, Void context) @Override protected RowExpression visitSymbolReference(SymbolReference node, Void context) { - Integer channel = layout.get(Symbol.from(node)); + VariableReferenceExpression variable = new VariableReferenceExpression(node.getName(), getType(node)); + Integer channel = layout.get(variable); if (channel != null) { - return field(channel, getType(node)); + return field(channel, variable.getType()); } - return new VariableReferenceExpression(node.getName(), getType(node)); + return variable; } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/SymbolToChannelTranslator.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/VariableToChannelTranslator.java similarity index 70% rename from presto-main/src/main/java/com/facebook/presto/sql/relational/SymbolToChannelTranslator.java rename to presto-main/src/main/java/com/facebook/presto/sql/relational/VariableToChannelTranslator.java index c3857da2e1bc3..739aced9a190f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/SymbolToChannelTranslator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/VariableToChannelTranslator.java @@ -21,38 +21,39 @@ import com.facebook.presto.spi.relation.RowExpressionVisitor; import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; -import com.facebook.presto.sql.planner.Symbol; import com.google.common.collect.ImmutableList; import java.util.Map; import static com.facebook.presto.sql.relational.Expressions.call; import static com.facebook.presto.sql.relational.Expressions.field; +import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.Maps.filterKeys; -public final class SymbolToChannelTranslator +public final class VariableToChannelTranslator { - private SymbolToChannelTranslator() {} + private VariableToChannelTranslator() {} /** * * Given an {@param expression} and a {@param layout}, translate the symbols in the expression to the corresponding channel. */ - public static RowExpression translate(RowExpression expression, Map layout) + public static RowExpression translate(RowExpression expression, Map layout) { return expression.accept(new Visitor(), layout); } private static class Visitor - implements RowExpressionVisitor> + implements RowExpressionVisitor> { @Override - public RowExpression visitInputReference(InputReferenceExpression input, Map layout) + public RowExpression visitInputReference(InputReferenceExpression input, Map layout) { throw new UnsupportedOperationException("encountered already-translated symbols"); } @Override - public RowExpression visitCall(CallExpression call, Map layout) + public RowExpression visitCall(CallExpression call, Map layout) { ImmutableList.Builder arguments = ImmutableList.builder(); call.getArguments().forEach(argument -> arguments.add(argument.accept(this, layout))); @@ -60,30 +61,33 @@ public RowExpression visitCall(CallExpression call, Map layout) } @Override - public RowExpression visitConstant(ConstantExpression literal, Map layout) + public RowExpression visitConstant(ConstantExpression literal, Map layout) { return literal; } @Override - public RowExpression visitLambda(LambdaDefinitionExpression lambda, Map layout) + public RowExpression visitLambda(LambdaDefinitionExpression lambda, Map layout) { return new LambdaDefinitionExpression(lambda.getArgumentTypes(), lambda.getArguments(), lambda.getBody().accept(this, layout)); } @Override - public RowExpression visitVariableReference(VariableReferenceExpression reference, Map layout) + public RowExpression visitVariableReference(VariableReferenceExpression reference, Map layout) { - Symbol symbol = new Symbol(reference.getName()); - if (layout.containsKey(symbol)) { - return field(layout.get(symbol), reference.getType()); + // We only use the variable name to find the reference in layout because SqlToRowExpression translator might optimize type cast + // to a variable with the same name as in layout but with a different type. + // TODO https://github.com/prestodb/presto/issues/12892 + Map candidate = filterKeys(layout, variable -> variable.getName().equals(reference.getName())); + if (!candidate.isEmpty()) { + return field(getOnlyElement(candidate.values()), reference.getType()); } // this is possible only for lambda return reference; } @Override - public RowExpression visitSpecialForm(SpecialFormExpression specialForm, Map layout) + public RowExpression visitSpecialForm(SpecialFormExpression specialForm, Map layout) { ImmutableList.Builder arguments = ImmutableList.builder(); specialForm.getArguments().forEach(argument -> arguments.add(argument.accept(this, layout))); diff --git a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java index b9eebb9d205ba..e423e76bc8787 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java @@ -15,10 +15,9 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.FunctionManager; -import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.sql.planner.SubPlan; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; @@ -244,14 +243,14 @@ public Void visitTableWriter(TableWriterNode node, Void context) @Override public Void visitStatisticsWriterNode(StatisticsWriterNode node, Void context) { - printNode(node, format("StatisticsWriterNode[%s]", Joiner.on(", ").join(node.getOutputSymbols())), NODE_COLORS.get(NodeType.ANALYZE_FINISH)); + printNode(node, format("StatisticsWriterNode[%s]", Joiner.on(", ").join(node.getOutputVariables())), NODE_COLORS.get(NodeType.ANALYZE_FINISH)); return node.getSource().accept(this, context); } @Override public Void visitTableFinish(TableFinishNode node, Void context) { - printNode(node, format("TableFinish[%s]", Joiner.on(", ").join(node.getOutputSymbols())), NODE_COLORS.get(NodeType.TABLE_FINISH)); + printNode(node, format("TableFinish[%s]", Joiner.on(", ").join(node.getOutputVariables())), NODE_COLORS.get(NodeType.TABLE_FINISH)); return node.getSource().accept(this, context); } @@ -272,7 +271,7 @@ public Void visitSort(SortNode node, Void context) @Override public Void visitMarkDistinct(MarkDistinctNode node, Void context) { - printNode(node, format("MarkDistinct[%s]", node.getMarkerSymbol()), format("%s => %s", node.getDistinctSymbols(), node.getMarkerSymbol()), NODE_COLORS.get(NodeType.MARK_DISTINCT)); + printNode(node, format("MarkDistinct[%s]", node.getMarkerVariable()), format("%s => %s", node.getDistinctVariables(), node.getMarkerVariable()), NODE_COLORS.get(NodeType.MARK_DISTINCT)); return node.getSource().accept(this, context); } @@ -332,13 +331,13 @@ public Void visitRemoteSource(RemoteSourceNode node, Void context) @Override public Void visitExchange(ExchangeNode node, Void context) { - List symbols = node.getOutputSymbols().stream() - .map(ArgumentBinding::columnBinding) - .collect(toImmutableList()); + String columns; if (node.getType() == REPARTITION) { - symbols = node.getPartitioningScheme().getPartitioning().getArguments(); + columns = Joiner.on(", ").join(node.getPartitioningScheme().getPartitioning().getArguments()); + } + else { + columns = Joiner.on(", ").join(node.getOutputVariables()); } - String columns = Joiner.on(", ").join(symbols); printNode(node, format("ExchangeNode[%s]", node.getType()), columns, NODE_COLORS.get(NodeType.EXCHANGE)); for (PlanNode planNode : node.getSources()) { planNode.accept(this, context); @@ -350,7 +349,7 @@ public Void visitExchange(ExchangeNode node, Void context) public Void visitAggregation(AggregationNode node, Void context) { StringBuilder builder = new StringBuilder(); - for (Map.Entry entry : node.getAggregations().entrySet()) { + for (Map.Entry entry : node.getAggregations().entrySet()) { builder.append(format("%s := %s\\n", entry.getKey(), formatAggregation(entry.getValue()))); } printNode(node, format("Aggregate[%s]", node.getStep()), builder.toString(), NODE_COLORS.get(NodeType.AGGREGATE)); @@ -393,7 +392,7 @@ public Void visitFilter(FilterNode node, Void context) public Void visitProject(ProjectNode node, Void context) { StringBuilder builder = new StringBuilder(); - for (Map.Entry entry : node.getAssignments().entrySet()) { + for (Map.Entry entry : node.getAssignments().entrySet()) { if ((entry.getValue() instanceof SymbolReference) && ((SymbolReference) entry.getValue()).getName().equals(entry.getKey().getName())) { // skip identity assignments @@ -409,11 +408,11 @@ public Void visitProject(ProjectNode node, Void context) @Override public Void visitUnnest(UnnestNode node, Void context) { - if (!node.getOrdinalitySymbol().isPresent()) { - printNode(node, format("Unnest[%s]", node.getUnnestSymbols().keySet()), NODE_COLORS.get(NodeType.UNNEST)); + if (!node.getOrdinalityVariable().isPresent()) { + printNode(node, format("Unnest[%s]", node.getUnnestVariables().keySet()), NODE_COLORS.get(NodeType.UNNEST)); } else { - printNode(node, format("Unnest[%s (ordinality)]", node.getUnnestSymbols().keySet()), NODE_COLORS.get(NodeType.UNNEST)); + printNode(node, format("Unnest[%s (ordinality)]", node.getUnnestVariables().keySet()), NODE_COLORS.get(NodeType.UNNEST)); } return node.getSource().accept(this, context); } @@ -489,7 +488,7 @@ public Void visitJoin(JoinNode node, Void context) @Override public Void visitSemiJoin(SemiJoinNode node, Void context) { - printNode(node, "SemiJoin", format("%s = %s", node.getSourceJoinSymbol(), node.getFilteringSourceJoinSymbol()), NODE_COLORS.get(NodeType.JOIN)); + printNode(node, "SemiJoin", format("%s = %s", node.getSourceJoinVariable(), node.getFilteringSourceJoinVariable()), NODE_COLORS.get(NodeType.JOIN)); node.getSource().accept(this, context); node.getFilteringSource().accept(this, context); @@ -554,8 +553,8 @@ public Void visitIndexJoin(IndexJoinNode node, Void context) List joinExpressions = new ArrayList<>(); for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) { joinExpressions.add(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, - clause.getProbe().toSymbolReference(), - clause.getIndex().toSymbolReference())); + new SymbolReference(clause.getProbe().getName()), + new SymbolReference(clause.getIndex().getName()))); } String criteria = Joiner.on(" AND ").join(joinExpressions); diff --git a/presto-main/src/test/java/com/facebook/presto/cost/PlanNodeStatsAssertion.java b/presto-main/src/test/java/com/facebook/presto/cost/PlanNodeStatsAssertion.java index fbc48cffb61a9..eb13b10fe1571 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/PlanNodeStatsAssertion.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/PlanNodeStatsAssertion.java @@ -13,12 +13,13 @@ */ package com.facebook.presto.cost; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.google.common.collect.ImmutableSet; import java.util.function.Consumer; import static com.facebook.presto.cost.EstimateAssertion.assertEstimateEquals; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.google.common.collect.Sets.union; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -49,26 +50,21 @@ public PlanNodeStatsAssertion outputRowsCountUnknown() return this; } - public PlanNodeStatsAssertion symbolStats(String symbolName, Consumer symbolStatsAssertionConsumer) + public PlanNodeStatsAssertion variableStats(VariableReferenceExpression variable, Consumer columnAssertionConsumer) { - return symbolStats(new Symbol(symbolName), symbolStatsAssertionConsumer); - } - - public PlanNodeStatsAssertion symbolStats(Symbol symbol, Consumer columnAssertionConsumer) - { - SymbolStatsAssertion columnAssertion = SymbolStatsAssertion.assertThat(actual.getSymbolStatistics(symbol)); + VariableStatsAssertion columnAssertion = VariableStatsAssertion.assertThat(actual.getVariableStatistics(variable)); columnAssertionConsumer.accept(columnAssertion); return this; } - public PlanNodeStatsAssertion symbolStatsUnknown(String symbolName) + public PlanNodeStatsAssertion variableStatsUnknown(String symbolName) { - return symbolStatsUnknown(new Symbol(symbolName)); + return variableStatsUnknown(new VariableReferenceExpression(symbolName, BIGINT)); } - public PlanNodeStatsAssertion symbolStatsUnknown(Symbol symbol) + public PlanNodeStatsAssertion variableStatsUnknown(VariableReferenceExpression variable) { - return symbolStats(symbol, + return variableStats(variable, columnStats -> columnStats .lowValueUnknown() .highValueUnknown() @@ -76,9 +72,9 @@ public PlanNodeStatsAssertion symbolStatsUnknown(Symbol symbol) .distinctValuesCountUnknown()); } - public PlanNodeStatsAssertion symbolsWithKnownStats(Symbol... symbols) + public PlanNodeStatsAssertion variablesWithKnownStats(VariableReferenceExpression... variable) { - assertEquals(actual.getSymbolsWithKnownStatistics(), ImmutableSet.copyOf(symbols), "symbols with known stats"); + assertEquals(actual.getVariablesWithKnownStatistics(), ImmutableSet.copyOf(variable), "variables with known stats"); return this; } @@ -86,18 +82,18 @@ public PlanNodeStatsAssertion equalTo(PlanNodeStatsEstimate expected) { assertEstimateEquals(actual.getOutputRowCount(), expected.getOutputRowCount(), "outputRowCount mismatch"); - for (Symbol symbol : union(expected.getSymbolsWithKnownStatistics(), actual.getSymbolsWithKnownStatistics())) { - assertSymbolStatsEqual(symbol, actual.getSymbolStatistics(symbol), expected.getSymbolStatistics(symbol)); + for (VariableReferenceExpression variable : union(expected.getVariablesWithKnownStatistics(), actual.getVariablesWithKnownStatistics())) { + assertVariableStatsEqual(variable, actual.getVariableStatistics(variable), expected.getVariableStatistics(variable)); } return this; } - private void assertSymbolStatsEqual(Symbol symbol, SymbolStatsEstimate actual, SymbolStatsEstimate expected) + private void assertVariableStatsEqual(VariableReferenceExpression variable, VariableStatsEstimate actual, VariableStatsEstimate expected) { - assertEstimateEquals(actual.getNullsFraction(), expected.getNullsFraction(), "nullsFraction mismatch for %s", symbol.getName()); - assertEstimateEquals(actual.getLowValue(), expected.getLowValue(), "lowValue mismatch for %s", symbol.getName()); - assertEstimateEquals(actual.getHighValue(), expected.getHighValue(), "highValue mismatch for %s", symbol.getName()); - assertEstimateEquals(actual.getDistinctValuesCount(), expected.getDistinctValuesCount(), "distinct values count mismatch for %s", symbol.getName()); - assertEstimateEquals(actual.getAverageRowSize(), expected.getAverageRowSize(), "average row size mismatch for %s", symbol.getName()); + assertEstimateEquals(actual.getNullsFraction(), expected.getNullsFraction(), "nullsFraction mismatch for %s", variable.getName()); + assertEstimateEquals(actual.getLowValue(), expected.getLowValue(), "lowValue mismatch for %s", variable.getName()); + assertEstimateEquals(actual.getHighValue(), expected.getHighValue(), "highValue mismatch for %s", variable.getName()); + assertEstimateEquals(actual.getDistinctValuesCount(), expected.getDistinctValuesCount(), "distinct values count mismatch for %s", variable.getName()); + assertEstimateEquals(actual.getAverageRowSize(), expected.getAverageRowSize(), "average row size mismatch for %s", variable.getName()); } } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestAggregationStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestAggregationStatsRule.java index 042ccb2231bbb..918b9f40208e7 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestAggregationStatsRule.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestAggregationStatsRule.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.cost; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; @@ -30,19 +30,19 @@ public void testAggregationWhenAllStatisticsAreKnown() { Consumer outputRowCountAndZStatsAreCalculated = check -> check .outputRowsCount(15) - .symbolStats("z", symbolStatsAssertion -> symbolStatsAssertion + .variableStats(new VariableReferenceExpression("z", BIGINT), symbolStatsAssertion -> symbolStatsAssertion .lowValue(10) .highValue(15) .distinctValuesCount(4) .nullsFraction(0.2)) - .symbolStats("y", symbolStatsAssertion -> symbolStatsAssertion + .variableStats(new VariableReferenceExpression("y", BIGINT), symbolStatsAssertion -> symbolStatsAssertion .lowValue(0) .highValue(3) .distinctValuesCount(3) .nullsFraction(0)); testAggregation( - SymbolStatsEstimate.builder() + VariableStatsEstimate.builder() .setLowValue(10) .setHighValue(15) .setDistinctValuesCount(4) @@ -51,7 +51,7 @@ public void testAggregationWhenAllStatisticsAreKnown() .check(outputRowCountAndZStatsAreCalculated); testAggregation( - SymbolStatsEstimate.builder() + VariableStatsEstimate.builder() .setLowValue(10) .setHighValue(15) .setDistinctValuesCount(4) @@ -60,17 +60,17 @@ public void testAggregationWhenAllStatisticsAreKnown() Consumer outputRowsCountAndZStatsAreNotFullyCalculated = check -> check .outputRowsCountUnknown() - .symbolStats("z", symbolStatsAssertion -> symbolStatsAssertion + .variableStats(new VariableReferenceExpression("z", BIGINT), symbolStatsAssertion -> symbolStatsAssertion .unknownRange() .distinctValuesCountUnknown() .nullsFractionUnknown()) - .symbolStats("y", symbolStatsAssertion -> symbolStatsAssertion + .variableStats(new VariableReferenceExpression("y", BIGINT), symbolStatsAssertion -> symbolStatsAssertion .unknownRange() .nullsFractionUnknown() .distinctValuesCountUnknown()); testAggregation( - SymbolStatsEstimate.builder() + VariableStatsEstimate.builder() .setLowValue(10) .setHighValue(15) .setNullsFraction(0.1) @@ -78,55 +78,55 @@ public void testAggregationWhenAllStatisticsAreKnown() .check(outputRowsCountAndZStatsAreNotFullyCalculated); testAggregation( - SymbolStatsEstimate.builder() + VariableStatsEstimate.builder() .setLowValue(10) .setHighValue(15) .build()) .check(outputRowsCountAndZStatsAreNotFullyCalculated); } - private StatsCalculatorAssertion testAggregation(SymbolStatsEstimate zStats) + private StatsCalculatorAssertion testAggregation(VariableStatsEstimate zStats) { return tester().assertStatsFor(pb -> pb .aggregation(ab -> ab - .addAggregation(pb.symbol("sum", BIGINT), expression("sum(x)"), ImmutableList.of(BIGINT)) - .addAggregation(pb.symbol("count", BIGINT), expression("count()"), ImmutableList.of()) - .addAggregation(pb.symbol("count_on_x", BIGINT), expression("count(x)"), ImmutableList.of(BIGINT)) - .singleGroupingSet(pb.symbol("y", BIGINT), pb.symbol("z", BIGINT)) - .source(pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT), pb.symbol("z", BIGINT))))) + .addAggregation(pb.variable(pb.symbol("sum", BIGINT)), expression("sum(x)"), ImmutableList.of(BIGINT)) + .addAggregation(pb.variable(pb.symbol("count", BIGINT)), expression("count()"), ImmutableList.of()) + .addAggregation(pb.variable(pb.symbol("count_on_x", BIGINT)), expression("count(x)"), ImmutableList.of(BIGINT)) + .singleGroupingSet(pb.variable("y", BIGINT), pb.variable("z", BIGINT)) + .source(pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT), pb.variable("z", BIGINT))))) .withSourceStats(PlanNodeStatsEstimate.builder() .setOutputRowCount(100) - .addSymbolStatistics(new Symbol("x"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("x", BIGINT), VariableStatsEstimate.builder() .setLowValue(1) .setHighValue(10) .setDistinctValuesCount(5) .setNullsFraction(0.3) .build()) - .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("y", BIGINT), VariableStatsEstimate.builder() .setLowValue(0) .setHighValue(3) .setDistinctValuesCount(3) .setNullsFraction(0) .build()) - .addSymbolStatistics(new Symbol("z"), zStats) + .addVariableStatistics(new VariableReferenceExpression("z", BIGINT), zStats) .build()) .check(check -> check - .symbolStats("sum", symbolStatsAssertion -> symbolStatsAssertion + .variableStats(new VariableReferenceExpression("sum", BIGINT), symbolStatsAssertion -> symbolStatsAssertion .lowValueUnknown() .highValueUnknown() .distinctValuesCountUnknown() .nullsFractionUnknown()) - .symbolStats("count", symbolStatsAssertion -> symbolStatsAssertion + .variableStats(new VariableReferenceExpression("count", BIGINT), symbolStatsAssertion -> symbolStatsAssertion .lowValueUnknown() .highValueUnknown() .distinctValuesCountUnknown() .nullsFractionUnknown()) - .symbolStats("count_on_x", symbolStatsAssertion -> symbolStatsAssertion + .variableStats(new VariableReferenceExpression("count_on_x", BIGINT), symbolStatsAssertion -> symbolStatsAssertion .lowValueUnknown() .highValueUnknown() .distinctValuesCountUnknown() .nullsFractionUnknown()) - .symbolStats("x", symbolStatsAssertion -> symbolStatsAssertion + .variableStats(new VariableReferenceExpression("x", BIGINT), symbolStatsAssertion -> symbolStatsAssertion .lowValueUnknown() .highValueUnknown() .distinctValuesCountUnknown() @@ -138,13 +138,13 @@ public void testAggregationStatsCappedToInputRows() { tester().assertStatsFor(pb -> pb .aggregation(ab -> ab - .addAggregation(pb.symbol("count_on_x", BIGINT), expression("count(x)"), ImmutableList.of(BIGINT)) - .singleGroupingSet(pb.symbol("y", BIGINT), pb.symbol("z", BIGINT)) - .source(pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT), pb.symbol("z", BIGINT))))) + .addAggregation(pb.variable(pb.symbol("count_on_x", BIGINT)), expression("count(x)"), ImmutableList.of(BIGINT)) + .singleGroupingSet(pb.variable("y", BIGINT), pb.variable("z", BIGINT)) + .source(pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT), pb.variable("z", BIGINT))))) .withSourceStats(PlanNodeStatsEstimate.builder() .setOutputRowCount(100) - .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder().setDistinctValuesCount(50).build()) - .addSymbolStatistics(new Symbol("z"), SymbolStatsEstimate.builder().setDistinctValuesCount(50).build()) + .addVariableStatistics(new VariableReferenceExpression("y", BIGINT), VariableStatsEstimate.builder().setDistinctValuesCount(50).build()) + .addVariableStatistics(new VariableReferenceExpression("z", BIGINT), VariableStatsEstimate.builder().setDistinctValuesCount(50).build()) .build()) .check(check -> check.outputRowsCount(100)); } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculator.java index d0586f1390fa7..4ace1f40e4fd2 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculator.java @@ -15,7 +15,8 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.spi.type.DoubleType; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.VarcharType; import com.facebook.presto.sql.planner.Symbol; @@ -37,7 +38,7 @@ import java.util.Objects; import java.util.function.Consumer; -import static com.facebook.presto.spi.type.StandardTypes.BIGINT; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.EQUAL; import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.GREATER_THAN; import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.LESS_THAN; @@ -57,16 +58,16 @@ public class TestComparisonStatsCalculator private Session session; private PlanNodeStatsEstimate standardInputStatistics; private TypeProvider types; - private SymbolStatsEstimate uStats; - private SymbolStatsEstimate wStats; - private SymbolStatsEstimate xStats; - private SymbolStatsEstimate yStats; - private SymbolStatsEstimate zStats; - private SymbolStatsEstimate leftOpenStats; - private SymbolStatsEstimate rightOpenStats; - private SymbolStatsEstimate unknownRangeStats; - private SymbolStatsEstimate emptyRangeStats; - private SymbolStatsEstimate varcharStats; + private VariableStatsEstimate uStats; + private VariableStatsEstimate wStats; + private VariableStatsEstimate xStats; + private VariableStatsEstimate yStats; + private VariableStatsEstimate zStats; + private VariableStatsEstimate leftOpenStats; + private VariableStatsEstimate rightOpenStats; + private VariableStatsEstimate unknownRangeStats; + private VariableStatsEstimate emptyRangeStats; + private VariableStatsEstimate varcharStats; @BeforeClass public void setUp() @@ -76,70 +77,70 @@ public void setUp() MetadataManager metadata = MetadataManager.createTestMetadataManager(); filterStatsCalculator = new FilterStatsCalculator(metadata, new ScalarStatsCalculator(metadata), new StatsNormalizer()); - uStats = SymbolStatsEstimate.builder() + uStats = VariableStatsEstimate.builder() .setAverageRowSize(8.0) .setDistinctValuesCount(300) .setLowValue(0) .setHighValue(20) .setNullsFraction(0.1) .build(); - wStats = SymbolStatsEstimate.builder() + wStats = VariableStatsEstimate.builder() .setAverageRowSize(8.0) .setDistinctValuesCount(30) .setLowValue(0) .setHighValue(20) .setNullsFraction(0.1) .build(); - xStats = SymbolStatsEstimate.builder() + xStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(40.0) .setLowValue(-10.0) .setHighValue(10.0) .setNullsFraction(0.25) .build(); - yStats = SymbolStatsEstimate.builder() + yStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(20.0) .setLowValue(0.0) .setHighValue(5.0) .setNullsFraction(0.5) .build(); - zStats = SymbolStatsEstimate.builder() + zStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(5.0) .setLowValue(-100.0) .setHighValue(100.0) .setNullsFraction(0.1) .build(); - leftOpenStats = SymbolStatsEstimate.builder() + leftOpenStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(50.0) .setLowValue(NEGATIVE_INFINITY) .setHighValue(15.0) .setNullsFraction(0.1) .build(); - rightOpenStats = SymbolStatsEstimate.builder() + rightOpenStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(50.0) .setLowValue(-15.0) .setHighValue(POSITIVE_INFINITY) .setNullsFraction(0.1) .build(); - unknownRangeStats = SymbolStatsEstimate.builder() + unknownRangeStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(50.0) .setLowValue(NEGATIVE_INFINITY) .setHighValue(POSITIVE_INFINITY) .setNullsFraction(0.1) .build(); - emptyRangeStats = SymbolStatsEstimate.builder() + emptyRangeStats = VariableStatsEstimate.builder() .setAverageRowSize(0.0) .setDistinctValuesCount(0.0) .setLowValue(NaN) .setHighValue(NaN) .setNullsFraction(1.0) .build(); - varcharStats = SymbolStatsEstimate.builder() + varcharStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(50.0) .setLowValue(NEGATIVE_INFINITY) @@ -147,34 +148,34 @@ public void setUp() .setNullsFraction(0.1) .build(); standardInputStatistics = PlanNodeStatsEstimate.builder() - .addSymbolStatistics(new Symbol("u"), uStats) - .addSymbolStatistics(new Symbol("w"), wStats) - .addSymbolStatistics(new Symbol("x"), xStats) - .addSymbolStatistics(new Symbol("y"), yStats) - .addSymbolStatistics(new Symbol("z"), zStats) - .addSymbolStatistics(new Symbol("leftOpen"), leftOpenStats) - .addSymbolStatistics(new Symbol("rightOpen"), rightOpenStats) - .addSymbolStatistics(new Symbol("unknownRange"), unknownRangeStats) - .addSymbolStatistics(new Symbol("emptyRange"), emptyRangeStats) - .addSymbolStatistics(new Symbol("varchar"), varcharStats) + .addVariableStatistics(new VariableReferenceExpression("u", DOUBLE), uStats) + .addVariableStatistics(new VariableReferenceExpression("w", DOUBLE), wStats) + .addVariableStatistics(new VariableReferenceExpression("x", DOUBLE), xStats) + .addVariableStatistics(new VariableReferenceExpression("y", DOUBLE), yStats) + .addVariableStatistics(new VariableReferenceExpression("z", DOUBLE), zStats) + .addVariableStatistics(new VariableReferenceExpression("leftOpen", DOUBLE), leftOpenStats) + .addVariableStatistics(new VariableReferenceExpression("rightOpen", DOUBLE), rightOpenStats) + .addVariableStatistics(new VariableReferenceExpression("unknownRange", DOUBLE), unknownRangeStats) + .addVariableStatistics(new VariableReferenceExpression("emptyRange", DOUBLE), emptyRangeStats) + .addVariableStatistics(new VariableReferenceExpression("varchar", VarcharType.createVarcharType(10)), varcharStats) .setOutputRowCount(1000.0) .build(); types = TypeProvider.copyOf(ImmutableMap.builder() - .put(new Symbol("u"), DoubleType.DOUBLE) - .put(new Symbol("w"), DoubleType.DOUBLE) - .put(new Symbol("x"), DoubleType.DOUBLE) - .put(new Symbol("y"), DoubleType.DOUBLE) - .put(new Symbol("z"), DoubleType.DOUBLE) - .put(new Symbol("leftOpen"), DoubleType.DOUBLE) - .put(new Symbol("rightOpen"), DoubleType.DOUBLE) - .put(new Symbol("unknownRange"), DoubleType.DOUBLE) - .put(new Symbol("emptyRange"), DoubleType.DOUBLE) + .put(new Symbol("u"), DOUBLE) + .put(new Symbol("w"), DOUBLE) + .put(new Symbol("x"), DOUBLE) + .put(new Symbol("y"), DOUBLE) + .put(new Symbol("z"), DOUBLE) + .put(new Symbol("leftOpen"), DOUBLE) + .put(new Symbol("rightOpen"), DOUBLE) + .put(new Symbol("unknownRange"), DOUBLE) + .put(new Symbol("emptyRange"), DOUBLE) .put(new Symbol("varchar"), VarcharType.createVarcharType(10)) .build()); } - private Consumer equalTo(SymbolStatsEstimate estimate) + private Consumer equalTo(VariableStatsEstimate estimate) { return symbolAssert -> { symbolAssert @@ -185,12 +186,12 @@ private Consumer equalTo(SymbolStatsEstimate estimate) }; } - private SymbolStatsEstimate updateNDV(SymbolStatsEstimate symbolStats, double delta) + private VariableStatsEstimate updateNDV(VariableStatsEstimate symbolStats, double delta) { return symbolStats.mapDistinctValuesCount(ndv -> ndv + delta); } - private SymbolStatsEstimate capNDV(SymbolStatsEstimate symbolStats, double rowCount) + private VariableStatsEstimate capNDV(VariableStatsEstimate symbolStats, double rowCount) { double ndv = symbolStats.getDistinctValuesCount(); double nulls = symbolStats.getNullsFraction(); @@ -205,7 +206,7 @@ private SymbolStatsEstimate capNDV(SymbolStatsEstimate symbolStats, double rowCo .mapNullsFraction(n -> nulls / 2); } - private SymbolStatsEstimate zeroNullsFraction(SymbolStatsEstimate symbolStats) + private VariableStatsEstimate zeroNullsFraction(VariableStatsEstimate symbolStats) { return symbolStats.mapNullsFraction(fraction -> 0.0); } @@ -223,7 +224,7 @@ public void verifyTestInputConsistent() new StatsNormalizer(), "standardInputStatistics", standardInputStatistics, - standardInputStatistics.getSymbolsWithKnownStatistics(), + standardInputStatistics.getVariablesWithKnownStatistics(), types); } @@ -233,7 +234,7 @@ public void symbolToLiteralEqualStats() // Simple case assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("y"), new DoubleLiteral("2.5"))) .outputRowsCount(25.0) // all rows minus nulls divided by distinct values count - .symbolStats("y", symbolAssert -> { + .variableStats(new VariableReferenceExpression("y", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(1.0) .lowValue(2.5) @@ -244,7 +245,7 @@ public void symbolToLiteralEqualStats() // Literal on the edge of symbol range assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("x"), new DoubleLiteral("10.0"))) .outputRowsCount(18.75) // all rows minus nulls divided by distinct values count - .symbolStats("x", symbolAssert -> { + .variableStats(new VariableReferenceExpression("x", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(1.0) .lowValue(10.0) @@ -255,7 +256,7 @@ public void symbolToLiteralEqualStats() // Literal out of symbol range assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("y"), new DoubleLiteral("10.0"))) .outputRowsCount(0.0) // all rows minus nulls divided by distinct values count - .symbolStats("y", symbolAssert -> { + .variableStats(new VariableReferenceExpression("y", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(0.0) .distinctValuesCount(0.0) .emptyRange() @@ -265,7 +266,7 @@ public void symbolToLiteralEqualStats() // Literal in left open range assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("leftOpen"), new DoubleLiteral("2.5"))) .outputRowsCount(18.0) // all rows minus nulls divided by distinct values count - .symbolStats("leftOpen", symbolAssert -> { + .variableStats(new VariableReferenceExpression("leftOpen", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(1.0) .lowValue(2.5) @@ -276,7 +277,7 @@ public void symbolToLiteralEqualStats() // Literal in right open range assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("rightOpen"), new DoubleLiteral("-2.5"))) .outputRowsCount(18.0) // all rows minus nulls divided by distinct values count - .symbolStats("rightOpen", symbolAssert -> { + .variableStats(new VariableReferenceExpression("rightOpen", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(1.0) .lowValue(-2.5) @@ -287,7 +288,7 @@ public void symbolToLiteralEqualStats() // Literal in unknown range assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("unknownRange"), new DoubleLiteral("0.0"))) .outputRowsCount(18.0) // all rows minus nulls divided by distinct values count - .symbolStats("unknownRange", symbolAssert -> { + .variableStats(new VariableReferenceExpression("unknownRange", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(1.0) .lowValue(0.0) @@ -298,12 +299,12 @@ public void symbolToLiteralEqualStats() // Literal in empty range assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("emptyRange"), new DoubleLiteral("0.0"))) .outputRowsCount(0.0) - .symbolStats("emptyRange", equalTo(emptyRangeStats)); + .variableStats(new VariableReferenceExpression("emptyRange", DOUBLE), equalTo(emptyRangeStats)); // Column with values not representable as double (unknown range) assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("varchar"), new StringLiteral("blah"))) .outputRowsCount(18.0) // all rows minus nulls divided by distinct values count - .symbolStats("varchar", symbolAssert -> { + .variableStats(new VariableReferenceExpression("varchar", VarcharType.createVarcharType(10)), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(1.0) .lowValue(NEGATIVE_INFINITY) @@ -318,7 +319,7 @@ public void symbolToLiteralNotEqualStats() // Simple case assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("y"), new DoubleLiteral("2.5"))) .outputRowsCount(475.0) // all rows minus nulls multiplied by ((distinct values - 1) / distinct values) - .symbolStats("y", symbolAssert -> { + .variableStats(new VariableReferenceExpression("y", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(19.0) .lowValue(0.0) @@ -329,7 +330,7 @@ public void symbolToLiteralNotEqualStats() // Literal on the edge of symbol range assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("x"), new DoubleLiteral("10.0"))) .outputRowsCount(731.25) // all rows minus nulls multiplied by ((distinct values - 1) / distinct values) - .symbolStats("x", symbolAssert -> { + .variableStats(new VariableReferenceExpression("x", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(39.0) .lowValue(-10.0) @@ -340,7 +341,7 @@ public void symbolToLiteralNotEqualStats() // Literal out of symbol range assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("y"), new DoubleLiteral("10.0"))) .outputRowsCount(500.0) // all rows minus nulls - .symbolStats("y", symbolAssert -> { + .variableStats(new VariableReferenceExpression("y", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(19.0) .lowValue(0.0) @@ -351,7 +352,7 @@ public void symbolToLiteralNotEqualStats() // Literal in left open range assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("leftOpen"), new DoubleLiteral("2.5"))) .outputRowsCount(882.0) // all rows minus nulls multiplied by ((distinct values - 1) / distinct values) - .symbolStats("leftOpen", symbolAssert -> { + .variableStats(new VariableReferenceExpression("leftOpen", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(49.0) .lowValueUnknown() @@ -362,7 +363,7 @@ public void symbolToLiteralNotEqualStats() // Literal in right open range assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("rightOpen"), new DoubleLiteral("-2.5"))) .outputRowsCount(882.0) // all rows minus nulls divided by distinct values count - .symbolStats("rightOpen", symbolAssert -> { + .variableStats(new VariableReferenceExpression("rightOpen", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(49.0) .lowValue(-15.0) @@ -373,7 +374,7 @@ public void symbolToLiteralNotEqualStats() // Literal in unknown range assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("unknownRange"), new DoubleLiteral("0.0"))) .outputRowsCount(882.0) // all rows minus nulls divided by distinct values count - .symbolStats("unknownRange", symbolAssert -> { + .variableStats(new VariableReferenceExpression("unknownRange", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(49.0) .lowValueUnknown() @@ -384,12 +385,12 @@ public void symbolToLiteralNotEqualStats() // Literal in empty range assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("emptyRange"), new DoubleLiteral("0.0"))) .outputRowsCount(0.0) - .symbolStats("emptyRange", equalTo(emptyRangeStats)); + .variableStats(new VariableReferenceExpression("emptyRange", DOUBLE), equalTo(emptyRangeStats)); // Column with values not representable as double (unknown range) assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("varchar"), new StringLiteral("blah"))) .outputRowsCount(882.0) // all rows minus nulls divided by distinct values count - .symbolStats("varchar", symbolAssert -> { + .variableStats(new VariableReferenceExpression("varchar", VarcharType.createVarcharType(10)), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(49.0) .lowValueUnknown() @@ -404,7 +405,7 @@ public void symbolToLiteralLessThanStats() // Simple case assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("y"), new DoubleLiteral("2.5"))) .outputRowsCount(250.0) // all rows minus nulls times range coverage (50%) - .symbolStats("y", symbolAssert -> { + .variableStats(new VariableReferenceExpression("y", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(10.0) .lowValue(0.0) @@ -415,7 +416,7 @@ public void symbolToLiteralLessThanStats() // Literal on the edge of symbol range (whole range included) assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("x"), new DoubleLiteral("10.0"))) .outputRowsCount(750.0) // all rows minus nulls times range coverage (100%) - .symbolStats("x", symbolAssert -> { + .variableStats(new VariableReferenceExpression("x", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(40.0) .lowValue(-10.0) @@ -426,7 +427,7 @@ public void symbolToLiteralLessThanStats() // Literal on the edge of symbol range (whole range excluded) assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("x"), new DoubleLiteral("-10.0"))) .outputRowsCount(18.75) // all rows minus nulls divided by NDV (one value from edge is included as approximation) - .symbolStats("x", symbolAssert -> { + .variableStats(new VariableReferenceExpression("x", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(1.0) .lowValue(-10.0) @@ -437,7 +438,7 @@ public void symbolToLiteralLessThanStats() // Literal range out of symbol range assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("y"), new DoubleLiteral("-10.0"))) .outputRowsCount(0.0) // all rows minus nulls times range coverage (0%) - .symbolStats("y", symbolAssert -> { + .variableStats(new VariableReferenceExpression("y", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(0.0) .distinctValuesCount(0.0) .emptyRange() @@ -447,7 +448,7 @@ public void symbolToLiteralLessThanStats() // Literal in left open range assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("leftOpen"), new DoubleLiteral("0.0"))) .outputRowsCount(450.0) // all rows minus nulls times range coverage (50% - heuristic) - .symbolStats("leftOpen", symbolAssert -> { + .variableStats(new VariableReferenceExpression("leftOpen", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(25.0) //(50% heuristic) .lowValueUnknown() @@ -458,7 +459,7 @@ public void symbolToLiteralLessThanStats() // Literal in right open range assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("rightOpen"), new DoubleLiteral("0.0"))) .outputRowsCount(225.0) // all rows minus nulls times range coverage (25% - heuristic) - .symbolStats("rightOpen", symbolAssert -> { + .variableStats(new VariableReferenceExpression("rightOpen", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(12.5) //(25% heuristic) .lowValue(-15.0) @@ -469,7 +470,7 @@ public void symbolToLiteralLessThanStats() // Literal in unknown range assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("unknownRange"), new DoubleLiteral("0.0"))) .outputRowsCount(450.0) // all rows minus nulls times range coverage (50% - heuristic) - .symbolStats("unknownRange", symbolAssert -> { + .variableStats(new VariableReferenceExpression("unknownRange", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(25.0) // (50% heuristic) .lowValueUnknown() @@ -480,7 +481,7 @@ public void symbolToLiteralLessThanStats() // Literal in empty range assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("emptyRange"), new DoubleLiteral("0.0"))) .outputRowsCount(0.0) - .symbolStats("emptyRange", equalTo(emptyRangeStats)); + .variableStats(new VariableReferenceExpression("emptyRange", DOUBLE), equalTo(emptyRangeStats)); } @Test @@ -489,7 +490,7 @@ public void symbolToLiteralGreaterThanStats() // Simple case assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("y"), new DoubleLiteral("2.5"))) .outputRowsCount(250.0) // all rows minus nulls times range coverage (50%) - .symbolStats("y", symbolAssert -> { + .variableStats(new VariableReferenceExpression("y", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(10.0) .lowValue(2.5) @@ -500,7 +501,7 @@ public void symbolToLiteralGreaterThanStats() // Literal on the edge of symbol range (whole range included) assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("x"), new DoubleLiteral("-10.0"))) .outputRowsCount(750.0) // all rows minus nulls times range coverage (100%) - .symbolStats("x", symbolAssert -> { + .variableStats(new VariableReferenceExpression("x", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(40.0) .lowValue(-10.0) @@ -511,7 +512,7 @@ public void symbolToLiteralGreaterThanStats() // Literal on the edge of symbol range (whole range excluded) assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("x"), new DoubleLiteral("10.0"))) .outputRowsCount(18.75) // all rows minus nulls divided by NDV (one value from edge is included as approximation) - .symbolStats("x", symbolAssert -> { + .variableStats(new VariableReferenceExpression("x", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(1.0) .lowValue(10.0) @@ -522,7 +523,7 @@ public void symbolToLiteralGreaterThanStats() // Literal range out of symbol range assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("y"), new DoubleLiteral("10.0"))) .outputRowsCount(0.0) // all rows minus nulls times range coverage (0%) - .symbolStats("y", symbolAssert -> { + .variableStats(new VariableReferenceExpression("y", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(0.0) .distinctValuesCount(0.0) .emptyRange() @@ -532,7 +533,7 @@ public void symbolToLiteralGreaterThanStats() // Literal in left open range assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("leftOpen"), new DoubleLiteral("0.0"))) .outputRowsCount(225.0) // all rows minus nulls times range coverage (25% - heuristic) - .symbolStats("leftOpen", symbolAssert -> { + .variableStats(new VariableReferenceExpression("leftOpen", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(12.5) //(25% heuristic) .lowValue(0.0) @@ -543,7 +544,7 @@ public void symbolToLiteralGreaterThanStats() // Literal in right open range assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("rightOpen"), new DoubleLiteral("0.0"))) .outputRowsCount(450.0) // all rows minus nulls times range coverage (50% - heuristic) - .symbolStats("rightOpen", symbolAssert -> { + .variableStats(new VariableReferenceExpression("rightOpen", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(25.0) //(50% heuristic) .lowValue(0.0) @@ -554,7 +555,7 @@ public void symbolToLiteralGreaterThanStats() // Literal in unknown range assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("unknownRange"), new DoubleLiteral("0.0"))) .outputRowsCount(450.0) // all rows minus nulls times range coverage (50% - heuristic) - .symbolStats("unknownRange", symbolAssert -> { + .variableStats(new VariableReferenceExpression("unknownRange", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4.0) .distinctValuesCount(25.0) // (50% heuristic) .lowValue(0.0) @@ -565,7 +566,7 @@ public void symbolToLiteralGreaterThanStats() // Literal in empty range assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("emptyRange"), new DoubleLiteral("0.0"))) .outputRowsCount(0.0) - .symbolStats("emptyRange", equalTo(emptyRangeStats)); + .variableStats(new VariableReferenceExpression("emptyRange", DOUBLE), equalTo(emptyRangeStats)); } @Test @@ -576,69 +577,69 @@ public void symbolToSymbolEqualStats() double rowCount = 2.7; assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("u"), new SymbolReference("w"))) .outputRowsCount(rowCount) - .symbolStats("u", equalTo(capNDV(zeroNullsFraction(uStats), rowCount))) - .symbolStats("w", equalTo(capNDV(zeroNullsFraction(wStats), rowCount))) - .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + .variableStats(new VariableReferenceExpression("u", DOUBLE), equalTo(capNDV(zeroNullsFraction(uStats), rowCount))) + .variableStats(new VariableReferenceExpression("w", DOUBLE), equalTo(capNDV(zeroNullsFraction(wStats), rowCount))) + .variableStats(new VariableReferenceExpression("z", DOUBLE), equalTo(capNDV(zStats, rowCount))); // One symbol's range is within the other's rowCount = 9.375; assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("x"), new SymbolReference("y"))) .outputRowsCount(rowCount) - .symbolStats("x", symbolAssert -> { + .variableStats(new VariableReferenceExpression("x", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4) .lowValue(0) .highValue(5) .distinctValuesCount(9.375 /* min(rowCount, ndv in intersection */) .nullsFraction(0); }) - .symbolStats("y", symbolAssert -> { + .variableStats(new VariableReferenceExpression("y", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(4) .lowValue(0) .highValue(5) .distinctValuesCount(9.375 /* min(rowCount, ndv in intersection */) .nullsFraction(0); }) - .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + .variableStats(new VariableReferenceExpression("z", DOUBLE), equalTo(capNDV(zStats, rowCount))); // Partially overlapping ranges rowCount = 16.875; assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("x"), new SymbolReference("w"))) .outputRowsCount(rowCount) - .symbolStats("x", symbolAssert -> { + .variableStats(new VariableReferenceExpression("x", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(6) .lowValue(0) .highValue(10) .distinctValuesCount(16.875 /* min(rowCount, ndv in intersection */) .nullsFraction(0); }) - .symbolStats("w", symbolAssert -> { + .variableStats(new VariableReferenceExpression("w", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(6) .lowValue(0) .highValue(10) .distinctValuesCount(16.875 /* min(rowCount, ndv in intersection */) .nullsFraction(0); }) - .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + .variableStats(new VariableReferenceExpression("z", DOUBLE), equalTo(capNDV(zStats, rowCount))); // None of the ranges is included in the other, and one symbol has much higher cardinality, so that it has bigger NDV in intersect than the other in total rowCount = 2.25; assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("x"), new SymbolReference("u"))) .outputRowsCount(rowCount) - .symbolStats("x", symbolAssert -> { + .variableStats(new VariableReferenceExpression("x", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(6) .lowValue(0) .highValue(10) .distinctValuesCount(2.25 /* min(rowCount, ndv in intersection */) .nullsFraction(0); }) - .symbolStats("u", symbolAssert -> { + .variableStats(new VariableReferenceExpression("u", DOUBLE), symbolAssert -> { symbolAssert.averageRowSize(6) .lowValue(0) .highValue(10) .distinctValuesCount(2.25 /* min(rowCount, ndv in intersection */) .nullsFraction(0); }) - .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + .variableStats(new VariableReferenceExpression("z", DOUBLE), equalTo(capNDV(zStats, rowCount))); } @Test @@ -648,55 +649,55 @@ public void symbolToSymbolNotEqual() double rowCount = 807.3; assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("u"), new SymbolReference("w"))) .outputRowsCount(rowCount) - .symbolStats("u", equalTo(capNDV(zeroNullsFraction(uStats), rowCount))) - .symbolStats("w", equalTo(capNDV(zeroNullsFraction(wStats), rowCount))) - .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + .variableStats(new VariableReferenceExpression("u", DOUBLE), equalTo(capNDV(zeroNullsFraction(uStats), rowCount))) + .variableStats(new VariableReferenceExpression("w", DOUBLE), equalTo(capNDV(zeroNullsFraction(wStats), rowCount))) + .variableStats(new VariableReferenceExpression("z", DOUBLE), equalTo(capNDV(zStats, rowCount))); // One symbol's range is within the other's rowCount = 365.625; assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("x"), new SymbolReference("y"))) .outputRowsCount(rowCount) - .symbolStats("x", equalTo(capNDV(zeroNullsFraction(xStats), rowCount))) - .symbolStats("y", equalTo(capNDV(zeroNullsFraction(yStats), rowCount))) - .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + .variableStats(new VariableReferenceExpression("x", DOUBLE), equalTo(capNDV(zeroNullsFraction(xStats), rowCount))) + .variableStats(new VariableReferenceExpression("y", DOUBLE), equalTo(capNDV(zeroNullsFraction(yStats), rowCount))) + .variableStats(new VariableReferenceExpression("z", DOUBLE), equalTo(capNDV(zStats, rowCount))); // Partially overlapping ranges rowCount = 658.125; assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("x"), new SymbolReference("w"))) .outputRowsCount(rowCount) - .symbolStats("x", equalTo(capNDV(zeroNullsFraction(xStats), rowCount))) - .symbolStats("w", equalTo(capNDV(zeroNullsFraction(wStats), rowCount))) - .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + .variableStats(new VariableReferenceExpression("x", DOUBLE), equalTo(capNDV(zeroNullsFraction(xStats), rowCount))) + .variableStats(new VariableReferenceExpression("w", DOUBLE), equalTo(capNDV(zeroNullsFraction(wStats), rowCount))) + .variableStats(new VariableReferenceExpression("z", DOUBLE), equalTo(capNDV(zStats, rowCount))); // None of the ranges is included in the other, and one symbol has much higher cardinality, so that it has bigger NDV in intersect than the other in total rowCount = 672.75; assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("x"), new SymbolReference("u"))) .outputRowsCount(rowCount) - .symbolStats("x", equalTo(capNDV(zeroNullsFraction(xStats), rowCount))) - .symbolStats("u", equalTo(capNDV(zeroNullsFraction(uStats), rowCount))) - .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + .variableStats(new VariableReferenceExpression("x", DOUBLE), equalTo(capNDV(zeroNullsFraction(xStats), rowCount))) + .variableStats(new VariableReferenceExpression("u", DOUBLE), equalTo(capNDV(zeroNullsFraction(uStats), rowCount))) + .variableStats(new VariableReferenceExpression("z", DOUBLE), equalTo(capNDV(zStats, rowCount))); } @Test public void symbolToCastExpressionNotEqual() { double rowCount = 807.3; - assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("u"), new Cast(new SymbolReference("w"), BIGINT))) + assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("u"), new Cast(new SymbolReference("w"), StandardTypes.BIGINT))) .outputRowsCount(rowCount) - .symbolStats("u", equalTo(capNDV(zeroNullsFraction(uStats), rowCount))) - .symbolStats("w", equalTo(capNDV(wStats, rowCount))) - .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + .variableStats(new VariableReferenceExpression("u", DOUBLE), equalTo(capNDV(zeroNullsFraction(uStats), rowCount))) + .variableStats(new VariableReferenceExpression("w", DOUBLE), equalTo(capNDV(wStats, rowCount))) + .variableStats(new VariableReferenceExpression("z", DOUBLE), equalTo(capNDV(zStats, rowCount))); rowCount = 897.0; - assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("u"), new Cast(new LongLiteral("10"), BIGINT))) + assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("u"), new Cast(new LongLiteral("10"), StandardTypes.BIGINT))) .outputRowsCount(rowCount) - .symbolStats("u", equalTo(capNDV(updateNDV(zeroNullsFraction(uStats), -1), rowCount))) - .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + .variableStats(new VariableReferenceExpression("u", DOUBLE), equalTo(capNDV(updateNDV(zeroNullsFraction(uStats), -1), rowCount))) + .variableStats(new VariableReferenceExpression("z", DOUBLE), equalTo(capNDV(zStats, rowCount))); } - private static void checkConsistent(StatsNormalizer normalizer, String source, PlanNodeStatsEstimate stats, Collection outputSymbols, TypeProvider types) + private static void checkConsistent(StatsNormalizer normalizer, String source, PlanNodeStatsEstimate stats, Collection outputVariables, TypeProvider types) { - PlanNodeStatsEstimate normalized = normalizer.normalize(stats, outputSymbols, types); + PlanNodeStatsEstimate normalized = normalizer.normalize(stats, outputVariables); if (Objects.equals(stats, normalized)) { return; } @@ -710,13 +711,13 @@ private static void checkConsistent(StatsNormalizer normalizer, String source, P normalized.getOutputRowCount())); } - for (Symbol symbol : stats.getSymbolsWithKnownStatistics()) { - if (!Objects.equals(stats.getSymbolStatistics(symbol), normalized.getSymbolStatistics(symbol))) { + for (VariableReferenceExpression variable : stats.getVariablesWithKnownStatistics()) { + if (!Objects.equals(stats.getVariableStatistics(variable), normalized.getVariableStatistics(variable))) { problems.add(format( - "Symbol stats for '%s' are \n\t\t\t\t\t%s, should be normalized to \n\t\t\t\t\t%s", - symbol, - stats.getSymbolStatistics(symbol), - normalized.getSymbolStatistics(symbol))); + "Variable stats for '%s' are \n\t\t\t\t\t%s, should be normalized to \n\t\t\t\t\t%s", + variable, + stats.getVariableStatistics(variable), + normalized.getVariableStatistics(variable))); } } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java index ed98d1ce792d3..ef7bd83d7fcf6 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java @@ -30,6 +30,7 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.parser.SqlParser; @@ -176,7 +177,7 @@ public void testTableScan() public void testProject() { TableScanNode tableScan = tableScan("ts", "orderkey"); - PlanNode project = project("project", tableScan, "string", new Cast(new SymbolReference("orderkey"), "VARCHAR")); + PlanNode project = project("project", tableScan, new VariableReferenceExpression("string", VARCHAR), new Cast(new SymbolReference("orderkey"), "VARCHAR")); Map costs = ImmutableMap.of("ts", cpuCost(1000)); Map stats = ImmutableMap.of( "project", statsEstimate(project, 4000), @@ -416,11 +417,26 @@ public void testAggregation() @Test public void testRepartitionedJoinWithExchange() { - TableScanNode ts1 = tableScan("ts1", "orderkey"); - TableScanNode ts2 = tableScan("ts2", "orderkey_0"); - ExchangeNode remoteExchange1 = systemPartitionedExchange(new PlanNodeId("re1"), REMOTE_STREAMING, ts1, ImmutableList.of(new Symbol("orderkey")), Optional.empty()); - ExchangeNode remoteExchange2 = systemPartitionedExchange(new PlanNodeId("re2"), REMOTE_STREAMING, ts2, ImmutableList.of(new Symbol("orderkey_0")), Optional.empty()); - ExchangeNode localExchange = systemPartitionedExchange(new PlanNodeId("le"), LOCAL, remoteExchange2, ImmutableList.of(new Symbol("orderkey_0")), Optional.empty()); + TableScanNode ts1 = tableScan("ts1", ImmutableList.of(new VariableReferenceExpression("orderkey", BIGINT))); + TableScanNode ts2 = tableScan("ts2", ImmutableList.of(new VariableReferenceExpression("orderkey_0", BIGINT))); + ExchangeNode remoteExchange1 = systemPartitionedExchange( + new PlanNodeId("re1"), + REMOTE_STREAMING, + ts1, + ImmutableList.of(new VariableReferenceExpression("orderkey", BIGINT)), + Optional.empty()); + ExchangeNode remoteExchange2 = systemPartitionedExchange( + new PlanNodeId("re2"), + REMOTE_STREAMING, + ts2, + ImmutableList.of(new VariableReferenceExpression("orderkey_0", BIGINT)), + Optional.empty()); + ExchangeNode localExchange = systemPartitionedExchange( + new PlanNodeId("le"), + LOCAL, + remoteExchange2, + ImmutableList.of(new VariableReferenceExpression("orderkey_0", BIGINT)), + Optional.empty()); JoinNode join = join("join", remoteExchange1, @@ -447,10 +463,15 @@ public void testRepartitionedJoinWithExchange() @Test public void testReplicatedJoinWithExchange() { - TableScanNode ts1 = tableScan("ts1", "orderkey"); - TableScanNode ts2 = tableScan("ts2", "orderkey_0"); + TableScanNode ts1 = tableScan("ts1", ImmutableList.of(new VariableReferenceExpression("orderkey", BIGINT))); + TableScanNode ts2 = tableScan("ts2", ImmutableList.of(new VariableReferenceExpression("orderkey_0", BIGINT))); ExchangeNode remoteExchange2 = replicatedExchange(new PlanNodeId("re2"), REMOTE_STREAMING, ts2); - ExchangeNode localExchange = systemPartitionedExchange(new PlanNodeId("le"), LOCAL, remoteExchange2, ImmutableList.of(new Symbol("orderkey_0")), Optional.empty()); + ExchangeNode localExchange = systemPartitionedExchange( + new PlanNodeId("le"), + LOCAL, + remoteExchange2, + ImmutableList.of(new VariableReferenceExpression("orderkey_0", BIGINT)), + Optional.empty()); JoinNode join = join("join", ts1, @@ -478,10 +499,10 @@ public void testUnion() { TableScanNode ts1 = tableScan("ts1", "orderkey"); TableScanNode ts2 = tableScan("ts2", "orderkey_0"); - ImmutableListMultimap.Builder outputMappings = ImmutableListMultimap.builder(); - outputMappings.put(new Symbol("orderkey_1"), new Symbol("orderkey")); - outputMappings.put(new Symbol("orderkey_1"), new Symbol("orderkey_0")); - UnionNode union = new UnionNode(new PlanNodeId("union"), ImmutableList.of(ts1, ts2), outputMappings.build(), ImmutableList.of(new Symbol("orderkey_1"))); + ImmutableListMultimap.Builder outputMappings = ImmutableListMultimap.builder(); + outputMappings.put(new VariableReferenceExpression("orderkey_1", BIGINT), new VariableReferenceExpression("orderkey", BIGINT)); + outputMappings.put(new VariableReferenceExpression("orderkey_1", BIGINT), new VariableReferenceExpression("orderkey_0", BIGINT)); + UnionNode union = new UnionNode(new PlanNodeId("union"), ImmutableList.of(ts1, ts2), outputMappings.build()); Map stats = ImmutableMap.of( "ts1", statsEstimate(ts1, 4000), "ts2", statsEstimate(ts2, 1000), @@ -709,22 +730,22 @@ CostAssertionBuilder hasUnknownComponents() private static PlanNodeStatsEstimate statsEstimate(PlanNode node, double outputSizeInBytes) { - return statsEstimate(node.getOutputSymbols(), outputSizeInBytes); + return statsEstimate(node.getOutputVariables(), outputSizeInBytes); } - private static PlanNodeStatsEstimate statsEstimate(Collection symbols, double outputSizeInBytes) + private static PlanNodeStatsEstimate statsEstimate(Collection variables, double outputSizeInBytes) { - checkArgument(symbols.size() > 0, "No symbols"); - checkArgument(ImmutableSet.copyOf(symbols).size() == symbols.size(), "Duplicate symbols"); + checkArgument(variables.size() > 0, "No variables"); + checkArgument(ImmutableSet.copyOf(variables).size() == variables.size(), "Duplicate variables"); - double rowCount = outputSizeInBytes / symbols.size() / AVERAGE_ROW_SIZE; + double rowCount = outputSizeInBytes / variables.size() / AVERAGE_ROW_SIZE; PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder() .setOutputRowCount(rowCount); - for (Symbol symbol : symbols) { - builder.addSymbolStatistics( - symbol, - SymbolStatsEstimate.builder() + for (VariableReferenceExpression variable : variables) { + builder.addVariableStatistics( + variable, + VariableStatsEstimate.builder() .setNullsFraction(0) .setAverageRowSize(AVERAGE_ROW_SIZE) .build()); @@ -734,11 +755,18 @@ private static PlanNodeStatsEstimate statsEstimate(Collection symbols, d private TableScanNode tableScan(String id, String... symbols) { - List symbolsList = Arrays.stream(symbols).map(Symbol::new).collect(toImmutableList()); - ImmutableMap.Builder assignments = ImmutableMap.builder(); + List variables = Arrays.stream(symbols) + .map(symbol -> new VariableReferenceExpression(symbol, BIGINT)) + .collect(toImmutableList()); + return tableScan(id, variables); + } + + private TableScanNode tableScan(String id, List variables) + { + ImmutableMap.Builder assignments = ImmutableMap.builder(); - for (Symbol symbol : symbolsList) { - assignments.put(symbol, new TpchColumnHandle("orderkey", BIGINT)); + for (VariableReferenceExpression variable : variables) { + assignments.put(variable, new TpchColumnHandle("orderkey", BIGINT)); } TpchTableHandle tableHandle = new TpchTableHandle("orders", 1.0); @@ -749,18 +777,18 @@ private TableScanNode tableScan(String id, String... symbols) tableHandle, TpchTransactionHandle.INSTANCE, Optional.of(new TpchTableLayoutHandle(tableHandle, TupleDomain.all()))), - symbolsList, + variables, assignments.build(), TupleDomain.all(), TupleDomain.all()); } - private PlanNode project(String id, PlanNode source, String symbol, Expression expression) + private PlanNode project(String id, PlanNode source, VariableReferenceExpression variable, Expression expression) { return new ProjectNode( new PlanNodeId(id), source, - Assignments.of(new Symbol(symbol), expression)); + Assignments.of(variable, expression)); } private AggregationNode aggregation(String id, PlanNode source) @@ -770,8 +798,8 @@ private AggregationNode aggregation(String id, PlanNode source) return new AggregationNode( new PlanNodeId(id), source, - ImmutableMap.of(new Symbol("count"), aggregation), - singleGroupingSet(source.getOutputSymbols()), + ImmutableMap.of(new VariableReferenceExpression("count", BIGINT), aggregation), + singleGroupingSet(source.getOutputVariables()), ImmutableList.of(), AggregationNode.Step.FINAL, Optional.empty(), @@ -788,7 +816,7 @@ private JoinNode join(String planNodeId, PlanNode left, PlanNode right, JoinNode ImmutableList.Builder criteria = ImmutableList.builder(); for (int i = 0; i < symbols.length; i += 2) { - criteria.add(new JoinNode.EquiJoinClause(new Symbol(symbols[i]), new Symbol(symbols[i + 1]))); + criteria.add(new JoinNode.EquiJoinClause(new VariableReferenceExpression(symbols[i], BIGINT), new VariableReferenceExpression(symbols[i + 1], BIGINT))); } return new JoinNode( @@ -797,9 +825,9 @@ private JoinNode join(String planNodeId, PlanNode left, PlanNode right, JoinNode left, right, criteria.build(), - ImmutableList.builder() - .addAll(left.getOutputSymbols()) - .addAll(right.getOutputSymbols()) + ImmutableList.builder() + .addAll(left.getOutputVariables()) + .addAll(right.getOutputVariables()) .build(), Optional.empty(), Optional.empty(), diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestExchangeStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestExchangeStatsRule.java index 9dbc189f723fc..12e52df096e6c 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestExchangeStatsRule.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestExchangeStatsRule.java @@ -14,7 +14,7 @@ package com.facebook.presto.cost; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; @@ -35,34 +35,34 @@ public void testExchange() tester().assertStatsFor(pb -> pb .exchange(exchangeBuilder -> exchangeBuilder - .addInputsSet(pb.symbol("i11", BIGINT), pb.symbol("i12", BIGINT), pb.symbol("i13", BIGINT), pb.symbol("i14", BIGINT)) - .addInputsSet(pb.symbol("i21", BIGINT), pb.symbol("i22", BIGINT), pb.symbol("i23", BIGINT), pb.symbol("i24", BIGINT)) + .addInputsSet(pb.variable("i11", BIGINT), pb.variable("i12", BIGINT), pb.variable("i13", BIGINT), pb.variable("i14", BIGINT)) + .addInputsSet(pb.variable("i21", BIGINT), pb.variable("i22", BIGINT), pb.variable("i23", BIGINT), pb.variable("i24", BIGINT)) .fixedHashDistributionParitioningScheme( - ImmutableList.of(pb.symbol("o1", BIGINT), pb.symbol("o2", BIGINT), pb.symbol("o3", BIGINT), pb.symbol("o4", BIGINT)), + ImmutableList.of(pb.variable("o1", BIGINT), pb.variable("o2", BIGINT), pb.variable("o3", BIGINT), pb.variable("o4", BIGINT)), emptyList()) - .addSource(pb.values(pb.symbol("i11", BIGINT), pb.symbol("i12", BIGINT), pb.symbol("i13", BIGINT), pb.symbol("i14", BIGINT))) - .addSource(pb.values(pb.symbol("i21", BIGINT), pb.symbol("i22", BIGINT), pb.symbol("i23", BIGINT), pb.symbol("i24", BIGINT))))) + .addSource(pb.values(pb.variable("i11", BIGINT), pb.variable("i12", BIGINT), pb.variable("i13", BIGINT), pb.variable("i14", BIGINT))) + .addSource(pb.values(pb.variable("i21", BIGINT), pb.variable("i22", BIGINT), pb.variable("i23", BIGINT), pb.variable("i24", BIGINT))))) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(10) - .addSymbolStatistics(new Symbol("i11"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i11", BIGINT), VariableStatsEstimate.builder() .setLowValue(1) .setHighValue(10) .setDistinctValuesCount(5) .setNullsFraction(0.3) .build()) - .addSymbolStatistics(new Symbol("i12"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i12", BIGINT), VariableStatsEstimate.builder() .setLowValue(0) .setHighValue(3) .setDistinctValuesCount(4) .setNullsFraction(0) .build()) - .addSymbolStatistics(new Symbol("i13"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i13", BIGINT), VariableStatsEstimate.builder() .setLowValue(10) .setHighValue(15) .setDistinctValuesCount(4) .setNullsFraction(0.1) .build()) - .addSymbolStatistics(new Symbol("i14"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i14", BIGINT), VariableStatsEstimate.builder() .setLowValue(10) .setHighValue(15) .setDistinctValuesCount(4) @@ -71,21 +71,21 @@ public void testExchange() .build()) .withSourceStats(1, PlanNodeStatsEstimate.builder() .setOutputRowCount(20) - .addSymbolStatistics(new Symbol("i21"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i21", BIGINT), VariableStatsEstimate.builder() .setLowValue(11) .setHighValue(20) .setNullsFraction(0.4) .build()) - .addSymbolStatistics(new Symbol("i22"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i22", BIGINT), VariableStatsEstimate.builder() .setLowValue(2) .setHighValue(7) .setDistinctValuesCount(3) .build()) - .addSymbolStatistics(new Symbol("i23"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i23", BIGINT), VariableStatsEstimate.builder() .setDistinctValuesCount(6) .setNullsFraction(0.2) .build()) - .addSymbolStatistics(new Symbol("i24"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i24", BIGINT), VariableStatsEstimate.builder() .setLowValue(10) .setHighValue(15) .setDistinctValuesCount(4) @@ -94,22 +94,22 @@ public void testExchange() .build()) .check(check -> check .outputRowsCount(30) - .symbolStats("o1", assertion -> assertion + .variableStats(new VariableReferenceExpression("o1", BIGINT), assertion -> assertion .lowValue(1) .highValue(20) .distinctValuesCountUnknown() .nullsFraction(0.3666666)) - .symbolStats("o2", assertion -> assertion + .variableStats(new VariableReferenceExpression("o2", BIGINT), assertion -> assertion .lowValue(0) .highValue(7) .distinctValuesCount(4) .nullsFractionUnknown()) - .symbolStats("o3", assertion -> assertion + .variableStats(new VariableReferenceExpression("o3", BIGINT), assertion -> assertion .lowValueUnknown() .highValueUnknown() .distinctValuesCount(6) .nullsFraction(0.1666667)) - .symbolStats("o4", assertion -> assertion + .variableStats(new VariableReferenceExpression("o4", BIGINT), assertion -> assertion .lowValue(10) .highValue(15) .distinctValuesCount(4) diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculator.java index 70b74701cea07..11a86c5bf96ba 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculator.java @@ -16,7 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.spi.relation.RowExpression; -import com.facebook.presto.spi.type.DoubleType; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.VarcharType; import com.facebook.presto.sql.TestingRowExpressionTranslator; @@ -28,6 +28,7 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static java.lang.Double.NEGATIVE_INFINITY; @@ -40,14 +41,14 @@ public class TestFilterStatsCalculator { private static final VarcharType MEDIUM_VARCHAR_TYPE = VarcharType.createVarcharType(100); - private SymbolStatsEstimate xStats; - private SymbolStatsEstimate yStats; - private SymbolStatsEstimate zStats; - private SymbolStatsEstimate leftOpenStats; - private SymbolStatsEstimate rightOpenStats; - private SymbolStatsEstimate unknownRangeStats; - private SymbolStatsEstimate emptyRangeStats; - private SymbolStatsEstimate mediumVarcharStats; + private VariableStatsEstimate xStats; + private VariableStatsEstimate yStats; + private VariableStatsEstimate zStats; + private VariableStatsEstimate leftOpenStats; + private VariableStatsEstimate rightOpenStats; + private VariableStatsEstimate unknownRangeStats; + private VariableStatsEstimate emptyRangeStats; + private VariableStatsEstimate mediumVarcharStats; private FilterStatsCalculator statsCalculator; private PlanNodeStatsEstimate standardInputStatistics; private TypeProvider standardTypes; @@ -58,56 +59,56 @@ public class TestFilterStatsCalculator public void setUp() throws Exception { - xStats = SymbolStatsEstimate.builder() + xStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(40.0) .setLowValue(-10.0) .setHighValue(10.0) .setNullsFraction(0.25) .build(); - yStats = SymbolStatsEstimate.builder() + yStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(20.0) .setLowValue(0.0) .setHighValue(5.0) .setNullsFraction(0.5) .build(); - zStats = SymbolStatsEstimate.builder() + zStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(5.0) .setLowValue(-100.0) .setHighValue(100.0) .setNullsFraction(0.1) .build(); - leftOpenStats = SymbolStatsEstimate.builder() + leftOpenStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(50.0) .setLowValue(NEGATIVE_INFINITY) .setHighValue(15.0) .setNullsFraction(0.1) .build(); - rightOpenStats = SymbolStatsEstimate.builder() + rightOpenStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(50.0) .setLowValue(-15.0) .setHighValue(POSITIVE_INFINITY) .setNullsFraction(0.1) .build(); - unknownRangeStats = SymbolStatsEstimate.builder() + unknownRangeStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(50.0) .setLowValue(NEGATIVE_INFINITY) .setHighValue(POSITIVE_INFINITY) .setNullsFraction(0.1) .build(); - emptyRangeStats = SymbolStatsEstimate.builder() + emptyRangeStats = VariableStatsEstimate.builder() .setAverageRowSize(0.0) .setDistinctValuesCount(0.0) .setLowValue(NaN) .setHighValue(NaN) .setNullsFraction(NaN) .build(); - mediumVarcharStats = SymbolStatsEstimate.builder() + mediumVarcharStats = VariableStatsEstimate.builder() .setAverageRowSize(85.0) .setDistinctValuesCount(165) .setLowValue(NEGATIVE_INFINITY) @@ -115,25 +116,25 @@ public void setUp() .setNullsFraction(0.34) .build(); standardInputStatistics = PlanNodeStatsEstimate.builder() - .addSymbolStatistics(new Symbol("x"), xStats) - .addSymbolStatistics(new Symbol("y"), yStats) - .addSymbolStatistics(new Symbol("z"), zStats) - .addSymbolStatistics(new Symbol("leftOpen"), leftOpenStats) - .addSymbolStatistics(new Symbol("rightOpen"), rightOpenStats) - .addSymbolStatistics(new Symbol("unknownRange"), unknownRangeStats) - .addSymbolStatistics(new Symbol("emptyRange"), emptyRangeStats) - .addSymbolStatistics(new Symbol("mediumVarchar"), mediumVarcharStats) + .addVariableStatistics(new VariableReferenceExpression("x", DOUBLE), xStats) + .addVariableStatistics(new VariableReferenceExpression("y", DOUBLE), yStats) + .addVariableStatistics(new VariableReferenceExpression("z", DOUBLE), zStats) + .addVariableStatistics(new VariableReferenceExpression("leftOpen", DOUBLE), leftOpenStats) + .addVariableStatistics(new VariableReferenceExpression("rightOpen", DOUBLE), rightOpenStats) + .addVariableStatistics(new VariableReferenceExpression("unknownRange", DOUBLE), unknownRangeStats) + .addVariableStatistics(new VariableReferenceExpression("emptyRange", DOUBLE), emptyRangeStats) + .addVariableStatistics(new VariableReferenceExpression("mediumVarchar", MEDIUM_VARCHAR_TYPE), mediumVarcharStats) .setOutputRowCount(1000.0) .build(); standardTypes = TypeProvider.copyOf(ImmutableMap.builder() - .put(new Symbol("x"), DoubleType.DOUBLE) - .put(new Symbol("y"), DoubleType.DOUBLE) - .put(new Symbol("z"), DoubleType.DOUBLE) - .put(new Symbol("leftOpen"), DoubleType.DOUBLE) - .put(new Symbol("rightOpen"), DoubleType.DOUBLE) - .put(new Symbol("unknownRange"), DoubleType.DOUBLE) - .put(new Symbol("emptyRange"), DoubleType.DOUBLE) + .put(new Symbol("x"), DOUBLE) + .put(new Symbol("y"), DOUBLE) + .put(new Symbol("z"), DOUBLE) + .put(new Symbol("leftOpen"), DOUBLE) + .put(new Symbol("rightOpen"), DOUBLE) + .put(new Symbol("unknownRange"), DOUBLE) + .put(new Symbol("emptyRange"), DOUBLE) .put(new Symbol("mediumVarchar"), MEDIUM_VARCHAR_TYPE) .build()); @@ -151,24 +152,24 @@ public void testBooleanLiteralStats() assertExpression("false") .outputRowsCount(0.0) - .symbolStats("x", SymbolStatsAssertion::empty) - .symbolStats("y", SymbolStatsAssertion::empty) - .symbolStats("z", SymbolStatsAssertion::empty) - .symbolStats("leftOpen", SymbolStatsAssertion::empty) - .symbolStats("rightOpen", SymbolStatsAssertion::empty) - .symbolStats("emptyRange", SymbolStatsAssertion::empty) - .symbolStats("unknownRange", SymbolStatsAssertion::empty); + .variableStats(new VariableReferenceExpression("x", DOUBLE), VariableStatsAssertion::empty) + .variableStats(new VariableReferenceExpression("y", DOUBLE), VariableStatsAssertion::empty) + .variableStats(new VariableReferenceExpression("z", DOUBLE), VariableStatsAssertion::empty) + .variableStats(new VariableReferenceExpression("leftOpen", DOUBLE), VariableStatsAssertion::empty) + .variableStats(new VariableReferenceExpression("rightOpen", DOUBLE), VariableStatsAssertion::empty) + .variableStats(new VariableReferenceExpression("emptyRange", DOUBLE), VariableStatsAssertion::empty) + .variableStats(new VariableReferenceExpression("unknownRange", DOUBLE), VariableStatsAssertion::empty); // `null AND null` is interpreted as null assertExpression("cast(null as boolean) AND cast(null as boolean)") .outputRowsCount(0.0) - .symbolStats("x", SymbolStatsAssertion::empty) - .symbolStats("y", SymbolStatsAssertion::empty) - .symbolStats("z", SymbolStatsAssertion::empty) - .symbolStats("leftOpen", SymbolStatsAssertion::empty) - .symbolStats("rightOpen", SymbolStatsAssertion::empty) - .symbolStats("emptyRange", SymbolStatsAssertion::empty) - .symbolStats("unknownRange", SymbolStatsAssertion::empty); + .variableStats(new VariableReferenceExpression("x", DOUBLE), VariableStatsAssertion::empty) + .variableStats(new VariableReferenceExpression("y", DOUBLE), VariableStatsAssertion::empty) + .variableStats(new VariableReferenceExpression("z", DOUBLE), VariableStatsAssertion::empty) + .variableStats(new VariableReferenceExpression("leftOpen", DOUBLE), VariableStatsAssertion::empty) + .variableStats(new VariableReferenceExpression("rightOpen", DOUBLE), VariableStatsAssertion::empty) + .variableStats(new VariableReferenceExpression("emptyRange", DOUBLE), VariableStatsAssertion::empty) + .variableStats(new VariableReferenceExpression("unknownRange", DOUBLE), VariableStatsAssertion::empty); // more complicated expressions with null assertExpression("cast(null as boolean) OR sin(x) > x").outputRowsCount(NaN); @@ -187,8 +188,8 @@ public void testComparison() double lessThan3Rows = 487.5; assertExpression("x < 3e0") .outputRowsCount(lessThan3Rows) - .symbolStats(new Symbol("x"), symbolAssert -> - symbolAssert.averageRowSize(4.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableAssert -> + variableAssert.averageRowSize(4.0) .lowValue(-10) .highValue(3) .distinctValuesCount(26) @@ -201,8 +202,8 @@ public void testComparison() for (String xEquals : ImmutableList.of("x = %s", "%s = x", "COALESCE(x * CAST(NULL AS BIGINT), x) = %s", "%s = CAST(x AS DOUBLE)")) { assertExpression(format(xEquals, minusThree)) .outputRowsCount(18.75) - .symbolStats(new Symbol("x"), symbolAssert -> - symbolAssert.averageRowSize(4.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableAssert -> + variableAssert.averageRowSize(4.0) .lowValue(-3) .highValue(-3) .distinctValuesCount(1) @@ -212,8 +213,8 @@ public void testComparison() for (String xLessThan : ImmutableList.of("x < %s", "%s > x", "%s > CAST(x AS DOUBLE)")) { assertExpression(format(xLessThan, minusThree)) .outputRowsCount(262.5) - .symbolStats(new Symbol("x"), symbolAssert -> - symbolAssert.averageRowSize(4.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableAssert -> + variableAssert.averageRowSize(4.0) .lowValue(-10) .highValue(-3) .distinctValuesCount(14) @@ -227,8 +228,8 @@ public void testOrStats() { assertExpression("x < 0e0 OR x < DOUBLE '-7.5'") .outputRowsCount(375) - .symbolStats(new Symbol("x"), symbolAssert -> - symbolAssert.averageRowSize(4.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableAssert -> + variableAssert.averageRowSize(4.0) .lowValue(-10.0) .highValue(0.0) .distinctValuesCount(20.0) @@ -236,8 +237,8 @@ public void testOrStats() assertExpression("x = 0e0 OR x = DOUBLE '-7.5'") .outputRowsCount(37.5) - .symbolStats(new Symbol("x"), symbolAssert -> - symbolAssert.averageRowSize(4.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableAssert -> + variableAssert.averageRowSize(4.0) .lowValue(-7.5) .highValue(0.0) .distinctValuesCount(2.0) @@ -245,8 +246,8 @@ public void testOrStats() assertExpression("x = 1e0 OR x = 3e0") .outputRowsCount(37.5) - .symbolStats(new Symbol("x"), symbolAssert -> - symbolAssert.averageRowSize(4.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableAssert -> + variableAssert.averageRowSize(4.0) .lowValue(1) .highValue(3) .distinctValuesCount(2) @@ -254,8 +255,8 @@ public void testOrStats() assertExpression("x = 1e0 OR 'a' = 'b' OR x = 3e0") .outputRowsCount(37.5) - .symbolStats(new Symbol("x"), symbolAssert -> - symbolAssert.averageRowSize(4.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableAssert -> + variableAssert.averageRowSize(4.0) .lowValue(1) .highValue(3) .distinctValuesCount(2) @@ -279,8 +280,8 @@ public void testAndStats() { assertExpression("x < 0e0 AND x > DOUBLE '-7.5'") .outputRowsCount(281.25) - .symbolStats(new Symbol("x"), symbolAssert -> - symbolAssert.averageRowSize(4.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableAssert -> + variableAssert.averageRowSize(4.0) .lowValue(-7.5) .highValue(0.0) .distinctValuesCount(15.0) @@ -289,14 +290,14 @@ public void testAndStats() // Impossible, with symbol-to-expression comparisons assertExpression("x = (0e0 + 1e0) AND x = (0e0 + 3e0)") .outputRowsCount(0) - .symbolStats(new Symbol("x"), SymbolStatsAssertion::emptyRange) - .symbolStats(new Symbol("y"), SymbolStatsAssertion::emptyRange); + .variableStats(new VariableReferenceExpression("x", DOUBLE), VariableStatsAssertion::emptyRange) + .variableStats(new VariableReferenceExpression("y", DOUBLE), VariableStatsAssertion::emptyRange); // first argument unknown assertExpression("json_array_contains(JSON '[]', x) AND x < 0e0") .outputRowsCount(337.5) - .symbolStats(new Symbol("x"), symbolAssert -> - symbolAssert.lowValue(-10) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableAssert -> + variableAssert.lowValue(-10) .highValue(0) .distinctValuesCount(20) .nullsFraction(0)); @@ -304,8 +305,8 @@ public void testAndStats() // second argument unknown assertExpression("x < 0e0 AND json_array_contains(JSON '[]', x)") .outputRowsCount(337.5) - .symbolStats(new Symbol("x"), symbolAssert -> - symbolAssert.lowValue(-10) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableAssert -> + variableAssert.lowValue(-10) .highValue(0) .distinctValuesCount(20) .nullsFraction(0)); @@ -323,23 +324,23 @@ public void testNotStats() { assertExpression("NOT(x < 0e0)") .outputRowsCount(625) // FIXME - nulls shouldn't be restored - .symbolStats(new Symbol("x"), symbolAssert -> - symbolAssert.averageRowSize(4.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableAssert -> + variableAssert.averageRowSize(4.0) .lowValue(-10.0) .highValue(10.0) .distinctValuesCount(20.0) .nullsFraction(0.4)) // FIXME - nulls shouldn't be restored - .symbolStats(new Symbol("y"), symbolAssert -> symbolAssert.isEqualTo(yStats)); + .variableStats(new VariableReferenceExpression("y", DOUBLE), variableAssert -> variableAssert.isEqualTo(yStats)); assertExpression("NOT(x IS NULL)") .outputRowsCount(750) - .symbolStats(new Symbol("x"), symbolAssert -> - symbolAssert.averageRowSize(4.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableAssert -> + variableAssert.averageRowSize(4.0) .lowValue(-10.0) .highValue(10.0) .distinctValuesCount(40.0) .nullsFraction(0)) - .symbolStats(new Symbol("y"), symbolAssert -> symbolAssert.isEqualTo(yStats)); + .variableStats(new VariableReferenceExpression("y", DOUBLE), variableAssert -> variableAssert.isEqualTo(yStats)); assertExpression("NOT(json_array_contains(JSON '[]', x))") .outputRowsCountUnknown(); @@ -350,14 +351,14 @@ public void testIsNullFilter() { assertExpression("x IS NULL") .outputRowsCount(250.0) - .symbolStats(new Symbol("x"), symbolStats -> - symbolStats.distinctValuesCount(0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableStats -> + variableStats.distinctValuesCount(0) .emptyRange() .nullsFraction(1.0)); assertExpression("emptyRange IS NULL") .outputRowsCount(1000.0) - .symbolStats(new Symbol("emptyRange"), SymbolStatsAssertion::empty); + .variableStats(new VariableReferenceExpression("emptyRange", DOUBLE), VariableStatsAssertion::empty); } @Test @@ -365,15 +366,15 @@ public void testIsNotNullFilter() { assertExpression("x IS NOT NULL") .outputRowsCount(750.0) - .symbolStats("x", symbolStats -> - symbolStats.distinctValuesCount(40.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableStats -> + variableStats.distinctValuesCount(40.0) .lowValue(-10.0) .highValue(10.0) .nullsFraction(0.0)); assertExpression("emptyRange IS NOT NULL") .outputRowsCount(0.0) - .symbolStats("emptyRange", SymbolStatsAssertion::empty); + .variableStats(new VariableReferenceExpression("emptyRange", DOUBLE), VariableStatsAssertion::empty); } @Test @@ -382,8 +383,8 @@ public void testBetweenOperatorFilter() // Only right side cut assertExpression("x BETWEEN 7.5e0 AND 12e0") .outputRowsCount(93.75) - .symbolStats("x", symbolStats -> - symbolStats.distinctValuesCount(5.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableStats -> + variableStats.distinctValuesCount(5.0) .lowValue(7.5) .highValue(10.0) .nullsFraction(0.0)); @@ -391,15 +392,15 @@ public void testBetweenOperatorFilter() // Only left side cut assertExpression("x BETWEEN DOUBLE '-12' AND DOUBLE '-7.5'") .outputRowsCount(93.75) - .symbolStats("x", symbolStats -> - symbolStats.distinctValuesCount(5.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableStats -> + variableStats.distinctValuesCount(5.0) .lowValue(-10) .highValue(-7.5) .nullsFraction(0.0)); assertExpression("x BETWEEN -12e0 AND -7.5e0") .outputRowsCount(93.75) - .symbolStats("x", symbolStats -> - symbolStats.distinctValuesCount(5.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableStats -> + variableStats.distinctValuesCount(5.0) .lowValue(-10) .highValue(-7.5) .nullsFraction(0.0)); @@ -407,8 +408,8 @@ public void testBetweenOperatorFilter() // Both sides cut assertExpression("x BETWEEN DOUBLE '-2.5' AND 2.5e0") .outputRowsCount(187.5) - .symbolStats("x", symbolStats -> - symbolStats.distinctValuesCount(10.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableStats -> + variableStats.distinctValuesCount(10.0) .lowValue(-2.5) .highValue(2.5) .nullsFraction(0.0)); @@ -416,8 +417,8 @@ public void testBetweenOperatorFilter() // Both sides cut unknownRange assertExpression("unknownRange BETWEEN 2.72e0 AND 3.14e0") .outputRowsCount(112.5) - .symbolStats("unknownRange", symbolStats -> - symbolStats.distinctValuesCount(6.25) + .variableStats(new VariableReferenceExpression("unknownRange", DOUBLE), variableStats -> + variableStats.distinctValuesCount(6.25) .lowValue(2.72) .highValue(3.14) .nullsFraction(0.0)); @@ -425,8 +426,8 @@ public void testBetweenOperatorFilter() // Left side open, cut on open side assertExpression("leftOpen BETWEEN DOUBLE '-10' AND 10e0") .outputRowsCount(180.0) - .symbolStats("leftOpen", symbolStats -> - symbolStats.distinctValuesCount(10.0) + .variableStats(new VariableReferenceExpression("leftOpen", DOUBLE), variableStats -> + variableStats.distinctValuesCount(10.0) .lowValue(-10.0) .highValue(10.0) .nullsFraction(0.0)); @@ -434,8 +435,8 @@ public void testBetweenOperatorFilter() // Right side open, cut on open side assertExpression("rightOpen BETWEEN DOUBLE '-10' AND 10e0") .outputRowsCount(180.0) - .symbolStats("rightOpen", symbolStats -> - symbolStats.distinctValuesCount(10.0) + .variableStats(new VariableReferenceExpression("rightOpen", DOUBLE), variableStats -> + variableStats.distinctValuesCount(10.0) .lowValue(-10.0) .highValue(10.0) .nullsFraction(0.0)); @@ -443,13 +444,13 @@ public void testBetweenOperatorFilter() // Filter all assertExpression("y BETWEEN 27.5e0 AND 107e0") .outputRowsCount(0.0) - .symbolStats("y", SymbolStatsAssertion::empty); + .variableStats(new VariableReferenceExpression("y", DOUBLE), VariableStatsAssertion::empty); // Filter nothing assertExpression("y BETWEEN DOUBLE '-100' AND 100e0") .outputRowsCount(500.0) - .symbolStats("y", symbolStats -> - symbolStats.distinctValuesCount(20.0) + .variableStats(new VariableReferenceExpression("y", DOUBLE), variableStats -> + variableStats.distinctValuesCount(20.0) .lowValue(0.0) .highValue(5.0) .nullsFraction(0.0)); @@ -457,8 +458,8 @@ public void testBetweenOperatorFilter() // Filter non exact match assertExpression("z BETWEEN DOUBLE '-100' AND 100e0") .outputRowsCount(900.0) - .symbolStats("z", symbolStats -> - symbolStats.distinctValuesCount(5.0) + .variableStats(new VariableReferenceExpression("z", DOUBLE), variableStats -> + variableStats.distinctValuesCount(5.0) .lowValue(-100.0) .highValue(100.0) .nullsFraction(0.0)); @@ -474,8 +475,8 @@ public void testSymbolEqualsSameSymbolFilter() { assertExpression("x = x") .outputRowsCount(750) - .symbolStats("x", symbolStats -> - SymbolStatsEstimate.builder() + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableStats -> + VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(40.0) .setLowValue(-10.0) @@ -491,29 +492,29 @@ public void testInPredicateFilter() // One value in range assertExpression("x IN (7.5e0)") .outputRowsCount(18.75) - .symbolStats("x", symbolStats -> - symbolStats.distinctValuesCount(1.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableStats -> + variableStats.distinctValuesCount(1.0) .lowValue(7.5) .highValue(7.5) .nullsFraction(0.0)); assertExpression("x IN (DOUBLE '-7.5')") .outputRowsCount(18.75) - .symbolStats("x", symbolStats -> - symbolStats.distinctValuesCount(1.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableStats -> + variableStats.distinctValuesCount(1.0) .lowValue(-7.5) .highValue(-7.5) .nullsFraction(0.0)); assertExpression("x IN (BIGINT '2' + 5.5e0)") .outputRowsCount(18.75) - .symbolStats("x", symbolStats -> - symbolStats.distinctValuesCount(1.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableStats -> + variableStats.distinctValuesCount(1.0) .lowValue(7.5) .highValue(7.5) .nullsFraction(0.0)); assertExpression("x IN (-7.5e0)") .outputRowsCount(18.75) - .symbolStats("x", symbolStats -> - symbolStats.distinctValuesCount(1.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableStats -> + variableStats.distinctValuesCount(1.0) .lowValue(-7.5) .highValue(-7.5) .nullsFraction(0.0)); @@ -521,14 +522,14 @@ public void testInPredicateFilter() // Multiple values in range assertExpression("x IN (1.5e0, 2.5e0, 7.5e0)") .outputRowsCount(56.25) - .symbolStats("x", symbolStats -> - symbolStats.distinctValuesCount(3.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableStats -> + variableStats.distinctValuesCount(3.0) .lowValue(1.5) .highValue(7.5) .nullsFraction(0.0)) - .symbolStats("y", symbolStats -> + .variableStats(new VariableReferenceExpression("y", DOUBLE), variableStats -> // Symbol not involved in the comparison should have stats basically unchanged - symbolStats.distinctValuesCount(20.0) + variableStats.distinctValuesCount(20.0) .lowValue(0.0) .highValue(5) .nullsFraction(0.5)); @@ -536,8 +537,8 @@ public void testInPredicateFilter() // Multiple values some in some out of range assertExpression("x IN (DOUBLE '-42', 1.5e0, 2.5e0, 7.5e0, 314e0)") .outputRowsCount(56.25) - .symbolStats("x", symbolStats -> - symbolStats.distinctValuesCount(3.0) + .variableStats(new VariableReferenceExpression("x", DOUBLE), variableStats -> + variableStats.distinctValuesCount(3.0) .lowValue(1.5) .highValue(7.5) .nullsFraction(0.0)); @@ -545,8 +546,8 @@ public void testInPredicateFilter() // Multiple values in unknown range assertExpression("unknownRange IN (DOUBLE '-42', 1.5e0, 2.5e0, 7.5e0, 314e0)") .outputRowsCount(90.0) - .symbolStats("unknownRange", symbolStats -> - symbolStats.distinctValuesCount(5.0) + .variableStats(new VariableReferenceExpression("unknownRange", DOUBLE), variableStats -> + variableStats.distinctValuesCount(5.0) .lowValue(-42.0) .highValue(314.0) .nullsFraction(0.0)); @@ -554,26 +555,26 @@ public void testInPredicateFilter() // Casted literals as value assertExpression(format("mediumVarchar IN (CAST('abc' AS %s))", MEDIUM_VARCHAR_TYPE.toString())) .outputRowsCount(4) - .symbolStats("mediumVarchar", symbolStats -> - symbolStats.distinctValuesCount(1) + .variableStats(new VariableReferenceExpression("mediumVarchar", MEDIUM_VARCHAR_TYPE), variableStats -> + variableStats.distinctValuesCount(1) .nullsFraction(0.0)); assertExpression(format("mediumVarchar IN (CAST('abc' AS %1$s), CAST('def' AS %1$s))", MEDIUM_VARCHAR_TYPE.toString())) .outputRowsCount(8) - .symbolStats("mediumVarchar", symbolStats -> - symbolStats.distinctValuesCount(2) + .variableStats(new VariableReferenceExpression("mediumVarchar", MEDIUM_VARCHAR_TYPE), variableStats -> + variableStats.distinctValuesCount(2) .nullsFraction(0.0)); // No value in range assertExpression("y IN (DOUBLE '-42', 6e0, 31.1341e0, DOUBLE '-0.000000002', 314e0)") .outputRowsCount(0.0) - .symbolStats("y", SymbolStatsAssertion::empty); + .variableStats(new VariableReferenceExpression("y", DOUBLE), VariableStatsAssertion::empty); // More values in range than distinct values assertExpression("z IN (DOUBLE '-1', 3.14e0, 0e0, 1e0, 2e0, 3e0, 4e0, 5e0, 6e0, 7e0, 8e0, DOUBLE '-2')") .outputRowsCount(900.0) - .symbolStats("z", symbolStats -> - symbolStats.distinctValuesCount(5.0) + .variableStats(new VariableReferenceExpression("z", DOUBLE), variableStats -> + variableStats.distinctValuesCount(5.0) .lowValue(-2.0) .highValue(8.0) .nullsFraction(0.0)); @@ -581,8 +582,8 @@ public void testInPredicateFilter() // Values in weird order assertExpression("z IN (DOUBLE '-1', 1e0, 0e0)") .outputRowsCount(540.0) - .symbolStats("z", symbolStats -> - symbolStats.distinctValuesCount(3.0) + .variableStats(new VariableReferenceExpression("z", DOUBLE), variableStats -> + variableStats.distinctValuesCount(3.0) .lowValue(-1.0) .highValue(1.0) .nullsFraction(0.0)); diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsRule.java index c790c7f96c9f1..646e53a30bc55 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsRule.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsRule.java @@ -14,11 +14,12 @@ package com.facebook.presto.cost; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; @@ -48,22 +49,22 @@ public void testEstimatableFilter() { tester().assertStatsFor(pb -> pb .filter(expression("i1 = 5"), - pb.values(pb.symbol("i1"), pb.symbol("i2"), pb.symbol("i3")))) + pb.values(pb.variable("i1", BIGINT), pb.variable("i2", BIGINT), pb.variable("i3", BIGINT)))) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(10) - .addSymbolStatistics(new Symbol("i1"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i1", BIGINT), VariableStatsEstimate.builder() .setLowValue(1) .setHighValue(10) .setDistinctValuesCount(5) .setNullsFraction(0) .build()) - .addSymbolStatistics(new Symbol("i2"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i2", BIGINT), VariableStatsEstimate.builder() .setLowValue(0) .setHighValue(3) .setDistinctValuesCount(4) .setNullsFraction(0) .build()) - .addSymbolStatistics(new Symbol("i3"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i3", BIGINT), VariableStatsEstimate.builder() .setLowValue(10) .setHighValue(15) .setDistinctValuesCount(4) @@ -72,19 +73,19 @@ public void testEstimatableFilter() .build()) .check(check -> check .outputRowsCount(2) - .symbolStats("i1", assertion -> assertion + .variableStats(new VariableReferenceExpression("i1", BIGINT), assertion -> assertion .lowValue(5) .highValue(5) .distinctValuesCount(1) .dataSizeUnknown() .nullsFraction(0)) - .symbolStats("i2", assertion -> assertion + .variableStats(new VariableReferenceExpression("i2", BIGINT), assertion -> assertion .lowValue(0) .highValue(3) .dataSizeUnknown() .distinctValuesCount(2) .nullsFraction(0)) - .symbolStats("i3", assertion -> assertion + .variableStats(new VariableReferenceExpression("i3", BIGINT), assertion -> assertion .lowValue(10) .highValue(15) .dataSizeUnknown() @@ -93,22 +94,22 @@ public void testEstimatableFilter() defaultFilterTester.assertStatsFor(pb -> pb .filter(expression("i1 = 5"), - pb.values(pb.symbol("i1"), pb.symbol("i2"), pb.symbol("i3")))) + pb.values(pb.variable("i1", BIGINT), pb.variable("i2", BIGINT), pb.variable("i3", BIGINT)))) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(10) - .addSymbolStatistics(new Symbol("i1"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i1", BIGINT), VariableStatsEstimate.builder() .setLowValue(1) .setHighValue(10) .setDistinctValuesCount(5) .setNullsFraction(0) .build()) - .addSymbolStatistics(new Symbol("i2"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i2", BIGINT), VariableStatsEstimate.builder() .setLowValue(0) .setHighValue(3) .setDistinctValuesCount(4) .setNullsFraction(0) .build()) - .addSymbolStatistics(new Symbol("i3"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i3", BIGINT), VariableStatsEstimate.builder() .setLowValue(10) .setHighValue(15) .setDistinctValuesCount(4) @@ -117,19 +118,19 @@ public void testEstimatableFilter() .build()) .check(check -> check .outputRowsCount(2) - .symbolStats("i1", assertion -> assertion + .variableStats(new VariableReferenceExpression("i1", BIGINT), assertion -> assertion .lowValue(5) .highValue(5) .distinctValuesCount(1) .dataSizeUnknown() .nullsFraction(0)) - .symbolStats("i2", assertion -> assertion + .variableStats(new VariableReferenceExpression("i2", BIGINT), assertion -> assertion .lowValue(0) .highValue(3) .dataSizeUnknown() .distinctValuesCount(2) .nullsFraction(0)) - .symbolStats("i3", assertion -> assertion + .variableStats(new VariableReferenceExpression("i3", BIGINT), assertion -> assertion .lowValue(10) .highValue(15) .dataSizeUnknown() @@ -144,22 +145,22 @@ public void testUnestimatableFunction() tester() .assertStatsFor(pb -> pb .filter(expression("sin(i1) = 1"), - pb.values(pb.symbol("i1"), pb.symbol("i2"), pb.symbol("i3")))) + pb.values(pb.variable("i1", BIGINT), pb.variable("i2", BIGINT), pb.variable("i3", BIGINT)))) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(10) - .addSymbolStatistics(new Symbol("i1"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i1", BIGINT), VariableStatsEstimate.builder() .setLowValue(1) .setHighValue(10) .setDistinctValuesCount(5) .setNullsFraction(0) .build()) - .addSymbolStatistics(new Symbol("i2"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i2", BIGINT), VariableStatsEstimate.builder() .setLowValue(0) .setHighValue(3) .setDistinctValuesCount(4) .setNullsFraction(0) .build()) - .addSymbolStatistics(new Symbol("i3"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i3", BIGINT), VariableStatsEstimate.builder() .setLowValue(10) .setHighValue(15) .setDistinctValuesCount(4) @@ -171,22 +172,22 @@ public void testUnestimatableFunction() // can't estimate function, but default filter factor is turned on defaultFilterTester.assertStatsFor(pb -> pb .filter(expression("sin(i1) = 1"), - pb.values(pb.symbol("i1"), pb.symbol("i2"), pb.symbol("i3")))) + pb.values(pb.variable("i1", BIGINT), pb.variable("i2", BIGINT), pb.variable("i3", BIGINT)))) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(10) - .addSymbolStatistics(new Symbol("i1"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i1", BIGINT), VariableStatsEstimate.builder() .setLowValue(1) .setHighValue(10) .setDistinctValuesCount(5) .setNullsFraction(0) .build()) - .addSymbolStatistics(new Symbol("i2"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i2", BIGINT), VariableStatsEstimate.builder() .setLowValue(0) .setHighValue(3) .setDistinctValuesCount(4) .setNullsFraction(0) .build()) - .addSymbolStatistics(new Symbol("i3"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("i3", BIGINT), VariableStatsEstimate.builder() .setLowValue(10) .setHighValue(15) .setDistinctValuesCount(4) @@ -195,19 +196,19 @@ public void testUnestimatableFunction() .build()) .check(check -> check .outputRowsCount(9) - .symbolStats("i1", assertion -> assertion + .variableStats(new VariableReferenceExpression("i1", BIGINT), assertion -> assertion .lowValue(1) .highValue(10) .dataSizeUnknown() .distinctValuesCount(5) .nullsFraction(0)) - .symbolStats("i2", assertion -> assertion + .variableStats(new VariableReferenceExpression("i2", BIGINT), assertion -> assertion .lowValue(0) .highValue(3) .dataSizeUnknown() .distinctValuesCount(4) .nullsFraction(0)) - .symbolStats("i3", assertion -> assertion + .variableStats(new VariableReferenceExpression("i3", BIGINT), assertion -> assertion .lowValue(10) .highValue(15) .dataSizeUnknown() diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestJoinStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestJoinStatsRule.java index 80a002ae9d345..c89bc047d2c03 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestJoinStatsRule.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestJoinStatsRule.java @@ -14,6 +14,7 @@ package com.facebook.presto.cost; import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; @@ -21,6 +22,7 @@ import com.facebook.presto.sql.planner.plan.JoinNode.EquiJoinClause; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -42,12 +44,12 @@ public class TestJoinStatsRule extends BaseStatsCalculatorTest { - private static final String LEFT_JOIN_COLUMN = "left_join_column"; - private static final String LEFT_JOIN_COLUMN_2 = "left_join_column_2"; - private static final String RIGHT_JOIN_COLUMN = "right_join_column"; - private static final String RIGHT_JOIN_COLUMN_2 = "right_join_column_2"; - private static final String LEFT_OTHER_COLUMN = "left_column"; - private static final String RIGHT_OTHER_COLUMN = "right_column"; + private static final VariableReferenceExpression LEFT_JOIN_COLUMN = new VariableReferenceExpression("left_join_column", BIGINT); + private static final VariableReferenceExpression LEFT_JOIN_COLUMN_2 = new VariableReferenceExpression("left_join_column_2", BIGINT); + private static final VariableReferenceExpression RIGHT_JOIN_COLUMN = new VariableReferenceExpression("right_join_column", DOUBLE); + private static final VariableReferenceExpression RIGHT_JOIN_COLUMN_2 = new VariableReferenceExpression("right_join_column_2", DOUBLE); + private static final VariableReferenceExpression LEFT_OTHER_COLUMN = new VariableReferenceExpression("left_column", BIGINT); + private static final VariableReferenceExpression RIGHT_OTHER_COLUMN = new VariableReferenceExpression("right_column", DOUBLE); private static final double LEFT_ROWS_COUNT = 500.0; private static final double RIGHT_ROWS_COUNT = 1000.0; @@ -65,18 +67,27 @@ public class TestJoinStatsRule private static final int RIGHT_JOIN_COLUMN_NDV = 15; private static final int RIGHT_JOIN_COLUMN_2_NDV = 15; - private static final SymbolStatistics LEFT_JOIN_COLUMN_STATS = - symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, LEFT_JOIN_COLUMN_NULLS, LEFT_JOIN_COLUMN_NDV); - private static final SymbolStatistics LEFT_JOIN_COLUMN_2_STATS = - symbolStatistics(LEFT_JOIN_COLUMN_2, 0.0, 200.0, LEFT_JOIN_COLUMN_2_NULLS, LEFT_JOIN_COLUMN_2_NDV); - private static final SymbolStatistics LEFT_OTHER_COLUMN_STATS = - symbolStatistics(LEFT_OTHER_COLUMN, 42, 42, 0.42, 1); - private static final SymbolStatistics RIGHT_JOIN_COLUMN_STATS = - symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, RIGHT_JOIN_COLUMN_NULLS, RIGHT_JOIN_COLUMN_NDV); - private static final SymbolStatistics RIGHT_JOIN_COLUMN_2_STATS = - symbolStatistics(RIGHT_JOIN_COLUMN_2, 100.0, 200.0, RIGHT_JOIN_COLUMN_2_NULLS, RIGHT_JOIN_COLUMN_2_NDV); - private static final SymbolStatistics RIGHT_OTHER_COLUMN_STATS = - symbolStatistics(RIGHT_OTHER_COLUMN, 24, 24, 0.24, 1); + private static final TypeProvider TYPES = TypeProvider.copyOf(ImmutableMap.builder() + .put(new Symbol(LEFT_JOIN_COLUMN.getName()), LEFT_JOIN_COLUMN.getType()) + .put(new Symbol(LEFT_JOIN_COLUMN_2.getName()), LEFT_JOIN_COLUMN_2.getType()) + .put(new Symbol(RIGHT_JOIN_COLUMN.getName()), RIGHT_JOIN_COLUMN.getType()) + .put(new Symbol(RIGHT_JOIN_COLUMN_2.getName()), RIGHT_JOIN_COLUMN_2.getType()) + .put(new Symbol(LEFT_OTHER_COLUMN.getName()), LEFT_OTHER_COLUMN.getType()) + .put(new Symbol(RIGHT_OTHER_COLUMN.getName()), RIGHT_OTHER_COLUMN.getType()) + .build()); + + private static final VariableStatistics LEFT_JOIN_COLUMN_STATS = + variableStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, LEFT_JOIN_COLUMN_NULLS, LEFT_JOIN_COLUMN_NDV); + private static final VariableStatistics LEFT_JOIN_COLUMN_2_STATS = + variableStatistics(LEFT_JOIN_COLUMN_2, 0.0, 200.0, LEFT_JOIN_COLUMN_2_NULLS, LEFT_JOIN_COLUMN_2_NDV); + private static final VariableStatistics LEFT_OTHER_COLUMN_STATS = + variableStatistics(LEFT_OTHER_COLUMN, 42, 42, 0.42, 1); + private static final VariableStatistics RIGHT_JOIN_COLUMN_STATS = + variableStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, RIGHT_JOIN_COLUMN_NULLS, RIGHT_JOIN_COLUMN_NDV); + private static final VariableStatistics RIGHT_JOIN_COLUMN_2_STATS = + variableStatistics(RIGHT_JOIN_COLUMN_2, 100.0, 200.0, RIGHT_JOIN_COLUMN_2_NULLS, RIGHT_JOIN_COLUMN_2_NDV); + private static final VariableStatistics RIGHT_OTHER_COLUMN_STATS = + variableStatistics(RIGHT_OTHER_COLUMN, 24, 24, 0.24, 1); private static final PlanNodeStatsEstimate LEFT_STATS = planNodeStats(LEFT_ROWS_COUNT, LEFT_JOIN_COLUMN_STATS, LEFT_OTHER_COLUMN_STATS); @@ -90,22 +101,14 @@ public class TestJoinStatsRule new FilterStatsCalculator(METADATA, new ScalarStatsCalculator(METADATA), NORMALIZER), NORMALIZER, 1.0); - private static final TypeProvider TYPES = TypeProvider.copyOf(ImmutableMap.builder() - .put(new Symbol(LEFT_JOIN_COLUMN), BIGINT) - .put(new Symbol(LEFT_JOIN_COLUMN_2), DOUBLE) - .put(new Symbol(RIGHT_JOIN_COLUMN), BIGINT) - .put(new Symbol(RIGHT_JOIN_COLUMN_2), DOUBLE) - .put(new Symbol(LEFT_OTHER_COLUMN), DOUBLE) - .put(new Symbol(RIGHT_OTHER_COLUMN), BIGINT) - .build()); @Test public void testStatsForInnerJoin() { double innerJoinRowCount = LEFT_ROWS_COUNT * RIGHT_ROWS_COUNT / LEFT_JOIN_COLUMN_NDV * LEFT_JOIN_COLUMN_NON_NULLS * RIGHT_JOIN_COLUMN_NON_NULLS; PlanNodeStatsEstimate innerJoinStats = planNodeStats(innerJoinRowCount, - symbolStatistics(LEFT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), - symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), + variableStatistics(LEFT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), + variableStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), LEFT_OTHER_COLUMN_STATS, RIGHT_OTHER_COLUMN_STATS); assertJoinStats(INNER, LEFT_STATS, RIGHT_STATS, innerJoinStats); @@ -117,19 +120,19 @@ public void testStatsForInnerJoinWithRepeatedClause() double innerJoinRowCount = LEFT_ROWS_COUNT * RIGHT_ROWS_COUNT / LEFT_JOIN_COLUMN_NDV * LEFT_JOIN_COLUMN_NON_NULLS * RIGHT_JOIN_COLUMN_NON_NULLS // driver join clause * UNKNOWN_FILTER_COEFFICIENT; // auxiliary join clause PlanNodeStatsEstimate innerJoinStats = planNodeStats(innerJoinRowCount, - symbolStatistics(LEFT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), - symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), + variableStatistics(LEFT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), + variableStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), LEFT_OTHER_COLUMN_STATS, RIGHT_OTHER_COLUMN_STATS); tester().assertStatsFor(pb -> { - Symbol leftJoinColumnSymbol = pb.symbol(LEFT_JOIN_COLUMN, BIGINT); - Symbol rightJoinColumnSymbol = pb.symbol(RIGHT_JOIN_COLUMN, DOUBLE); - Symbol leftOtherColumnSymbol = pb.symbol(LEFT_OTHER_COLUMN, BIGINT); - Symbol rightOtherColumnSymbol = pb.symbol(RIGHT_OTHER_COLUMN, DOUBLE); + VariableReferenceExpression leftJoinColumnVariable = pb.variable(LEFT_JOIN_COLUMN); + VariableReferenceExpression rightJoinColumnVariable = pb.variable(RIGHT_JOIN_COLUMN); + VariableReferenceExpression leftOtherColumnVariable = pb.variable(LEFT_OTHER_COLUMN); + VariableReferenceExpression rightOtherColumnVariable = pb.variable(RIGHT_OTHER_COLUMN); return pb - .join(INNER, pb.values(leftJoinColumnSymbol, leftOtherColumnSymbol), - pb.values(rightJoinColumnSymbol, rightOtherColumnSymbol), - new EquiJoinClause(leftJoinColumnSymbol, rightJoinColumnSymbol), new EquiJoinClause(leftJoinColumnSymbol, rightJoinColumnSymbol)); + .join(INNER, pb.values(leftJoinColumnVariable, leftOtherColumnVariable), + pb.values(rightJoinColumnVariable, rightOtherColumnVariable), + new EquiJoinClause(pb.variable(leftJoinColumnVariable), pb.variable(rightJoinColumnVariable)), new EquiJoinClause(pb.variable(leftJoinColumnVariable), pb.variable(rightJoinColumnVariable))); }).withSourceStats(0, LEFT_STATS) .withSourceStats(1, RIGHT_STATS) .check(stats -> stats.equalTo(innerJoinStats)); @@ -142,20 +145,20 @@ public void testStatsForInnerJoinWithTwoEquiClauses() LEFT_ROWS_COUNT * RIGHT_ROWS_COUNT / LEFT_JOIN_COLUMN_2_NDV * LEFT_JOIN_COLUMN_2_NON_NULLS * RIGHT_JOIN_COLUMN_2_NON_NULLS // driver join clause * UNKNOWN_FILTER_COEFFICIENT; // auxiliary join clause PlanNodeStatsEstimate innerJoinStats = planNodeStats(innerJoinRowCount, - symbolStatistics(LEFT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), - symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), - symbolStatistics(LEFT_JOIN_COLUMN_2, 100.0, 200.0, 0.0, RIGHT_JOIN_COLUMN_2_NDV), - symbolStatistics(RIGHT_JOIN_COLUMN_2, 100.0, 200.0, 0.0, RIGHT_JOIN_COLUMN_2_NDV)); + variableStatistics(LEFT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), + variableStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), + variableStatistics(LEFT_JOIN_COLUMN_2, 100.0, 200.0, 0.0, RIGHT_JOIN_COLUMN_2_NDV), + variableStatistics(RIGHT_JOIN_COLUMN_2, 100.0, 200.0, 0.0, RIGHT_JOIN_COLUMN_2_NDV)); tester().assertStatsFor(pb -> { - Symbol leftJoinColumnSymbol = pb.symbol(LEFT_JOIN_COLUMN, BIGINT); - Symbol rightJoinColumnSymbol = pb.symbol(RIGHT_JOIN_COLUMN, DOUBLE); - Symbol leftJoinColumnSymbol2 = pb.symbol(LEFT_JOIN_COLUMN_2, BIGINT); - Symbol rightJoinColumnSymbol2 = pb.symbol(RIGHT_JOIN_COLUMN_2, DOUBLE); + VariableReferenceExpression leftJoinColumnVariable = pb.variable(LEFT_JOIN_COLUMN); + VariableReferenceExpression rightJoinColumnVariable = pb.variable(RIGHT_JOIN_COLUMN); + VariableReferenceExpression leftJoinColumnVariable2 = pb.variable(LEFT_JOIN_COLUMN_2); + VariableReferenceExpression rightJoinColumnVariable2 = pb.variable(RIGHT_JOIN_COLUMN_2); return pb - .join(INNER, pb.values(leftJoinColumnSymbol, leftJoinColumnSymbol2), - pb.values(rightJoinColumnSymbol, rightJoinColumnSymbol2), - new EquiJoinClause(leftJoinColumnSymbol2, rightJoinColumnSymbol2), new EquiJoinClause(leftJoinColumnSymbol, rightJoinColumnSymbol)); + .join(INNER, pb.values(leftJoinColumnVariable, leftJoinColumnVariable2), + pb.values(rightJoinColumnVariable, rightJoinColumnVariable2), + new EquiJoinClause(pb.variable(leftJoinColumnVariable2), pb.variable(rightJoinColumnVariable2)), new EquiJoinClause(pb.variable(leftJoinColumnVariable), pb.variable(rightJoinColumnVariable))); }).withSourceStats(0, planNodeStats(LEFT_ROWS_COUNT, LEFT_JOIN_COLUMN_STATS, LEFT_JOIN_COLUMN_2_STATS)) .withSourceStats(1, planNodeStats(RIGHT_ROWS_COUNT, RIGHT_JOIN_COLUMN_STATS, RIGHT_JOIN_COLUMN_2_STATS)) .check(stats -> stats.equalTo(innerJoinStats)); @@ -169,22 +172,22 @@ public void testStatsForInnerJoinWithTwoEquiClausesAndNonEqualityFunction() * UNKNOWN_FILTER_COEFFICIENT // auxiliary join clause * 0.3333333333; // LEFT_JOIN_COLUMN < 10 non equality filter PlanNodeStatsEstimate innerJoinStats = planNodeStats(innerJoinRowCount, - symbolStatistics(LEFT_JOIN_COLUMN, 5.0, 10.0, 0.0, RIGHT_JOIN_COLUMN_NDV * 0.3333333333), - symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), - symbolStatistics(LEFT_JOIN_COLUMN_2, 100.0, 200.0, 0.0, RIGHT_JOIN_COLUMN_2_NDV), - symbolStatistics(RIGHT_JOIN_COLUMN_2, 100.0, 200.0, 0.0, RIGHT_JOIN_COLUMN_2_NDV)); + variableStatistics(LEFT_JOIN_COLUMN, 5.0, 10.0, 0.0, RIGHT_JOIN_COLUMN_NDV * 0.3333333333), + variableStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), + variableStatistics(LEFT_JOIN_COLUMN_2, 100.0, 200.0, 0.0, RIGHT_JOIN_COLUMN_2_NDV), + variableStatistics(RIGHT_JOIN_COLUMN_2, 100.0, 200.0, 0.0, RIGHT_JOIN_COLUMN_2_NDV)); tester().assertStatsFor(pb -> { - Symbol leftJoinColumnSymbol = pb.symbol(LEFT_JOIN_COLUMN, BIGINT); - Symbol rightJoinColumnSymbol = pb.symbol(RIGHT_JOIN_COLUMN, DOUBLE); - Symbol leftJoinColumnSymbol2 = pb.symbol(LEFT_JOIN_COLUMN_2, BIGINT); - Symbol rightJoinColumnSymbol2 = pb.symbol(RIGHT_JOIN_COLUMN_2, DOUBLE); - ComparisonExpression leftJoinColumnLessThanTen = new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN, leftJoinColumnSymbol.toSymbolReference(), new LongLiteral("10")); + VariableReferenceExpression leftJoinColumn = pb.variable(LEFT_JOIN_COLUMN); + VariableReferenceExpression rightJoinColumn = pb.variable(RIGHT_JOIN_COLUMN); + VariableReferenceExpression leftJoinColumn2 = pb.variable(LEFT_JOIN_COLUMN_2); + VariableReferenceExpression rightJoinColumn2 = pb.variable(RIGHT_JOIN_COLUMN_2); + ComparisonExpression leftJoinColumnLessThanTen = new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN, new SymbolReference(leftJoinColumn.getName()), new LongLiteral("10")); return pb - .join(INNER, pb.values(leftJoinColumnSymbol, leftJoinColumnSymbol2), - pb.values(rightJoinColumnSymbol, rightJoinColumnSymbol2), - ImmutableList.of(new EquiJoinClause(leftJoinColumnSymbol2, rightJoinColumnSymbol2), new EquiJoinClause(leftJoinColumnSymbol, rightJoinColumnSymbol)), - ImmutableList.of(leftJoinColumnSymbol, leftJoinColumnSymbol2, rightJoinColumnSymbol, rightJoinColumnSymbol2), + .join(INNER, pb.values(leftJoinColumn, leftJoinColumn2), + pb.values(rightJoinColumn, rightJoinColumn2), + ImmutableList.of(new EquiJoinClause(leftJoinColumn2, rightJoinColumn2), new EquiJoinClause(leftJoinColumn, rightJoinColumn)), + ImmutableList.of(leftJoinColumn, leftJoinColumn2, rightJoinColumn, rightJoinColumn2), Optional.of(leftJoinColumnLessThanTen)); }).withSourceStats(0, planNodeStats(LEFT_ROWS_COUNT, LEFT_JOIN_COLUMN_STATS, LEFT_JOIN_COLUMN_2_STATS)) .withSourceStats(1, planNodeStats(RIGHT_ROWS_COUNT, RIGHT_JOIN_COLUMN_STATS, RIGHT_JOIN_COLUMN_2_STATS)) @@ -195,11 +198,11 @@ public void testStatsForInnerJoinWithTwoEquiClausesAndNonEqualityFunction() public void testJoinComplementStats() { PlanNodeStatsEstimate expected = planNodeStats(LEFT_ROWS_COUNT * (LEFT_JOIN_COLUMN_NULLS + LEFT_JOIN_COLUMN_NON_NULLS / 4), - symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, LEFT_JOIN_COLUMN_NULLS / (LEFT_JOIN_COLUMN_NULLS + LEFT_JOIN_COLUMN_NON_NULLS / 4), 5), + variableStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, LEFT_JOIN_COLUMN_NULLS / (LEFT_JOIN_COLUMN_NULLS + LEFT_JOIN_COLUMN_NON_NULLS / 4), 5), LEFT_OTHER_COLUMN_STATS); PlanNodeStatsEstimate actual = JOIN_STATS_RULE.calculateJoinComplementStats( Optional.empty(), - ImmutableList.of(new EquiJoinClause(new Symbol(LEFT_JOIN_COLUMN), new Symbol(RIGHT_JOIN_COLUMN))), + ImmutableList.of(new EquiJoinClause(LEFT_JOIN_COLUMN, RIGHT_JOIN_COLUMN)), LEFT_STATS, RIGHT_STATS, TYPES); @@ -212,12 +215,11 @@ public void testRightJoinComplementStats() PlanNodeStatsEstimate expected = NORMALIZER.normalize( planNodeStats( RIGHT_ROWS_COUNT * RIGHT_JOIN_COLUMN_NULLS, - symbolStatistics(RIGHT_JOIN_COLUMN, NaN, NaN, 1.0, 0), - RIGHT_OTHER_COLUMN_STATS), - TYPES); + variableStatistics(RIGHT_JOIN_COLUMN, NaN, NaN, 1.0, 0), + RIGHT_OTHER_COLUMN_STATS)); PlanNodeStatsEstimate actual = JOIN_STATS_RULE.calculateJoinComplementStats( Optional.empty(), - ImmutableList.of(new EquiJoinClause(new Symbol(RIGHT_JOIN_COLUMN), new Symbol(LEFT_JOIN_COLUMN))), + ImmutableList.of(new EquiJoinClause(RIGHT_JOIN_COLUMN, LEFT_JOIN_COLUMN)), RIGHT_STATS, LEFT_STATS, TYPES); @@ -227,7 +229,7 @@ public void testRightJoinComplementStats() @Test public void testLeftJoinComplementStatsWithNoClauses() { - PlanNodeStatsEstimate expected = NORMALIZER.normalize(LEFT_STATS.mapOutputRowCount(rowCount -> 0.0), TYPES); + PlanNodeStatsEstimate expected = NORMALIZER.normalize(LEFT_STATS.mapOutputRowCount(rowCount -> 0.0)); PlanNodeStatsEstimate actual = JOIN_STATS_RULE.calculateJoinComplementStats( Optional.empty(), ImmutableList.of(), @@ -242,12 +244,14 @@ public void testLeftJoinComplementStatsWithMultipleClauses() { PlanNodeStatsEstimate expected = planNodeStats( LEFT_ROWS_COUNT * (LEFT_JOIN_COLUMN_NULLS + LEFT_JOIN_COLUMN_NON_NULLS / 4), - symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, LEFT_JOIN_COLUMN_NULLS / (LEFT_JOIN_COLUMN_NULLS + LEFT_JOIN_COLUMN_NON_NULLS / 4), 5), + variableStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, LEFT_JOIN_COLUMN_NULLS / (LEFT_JOIN_COLUMN_NULLS + LEFT_JOIN_COLUMN_NON_NULLS / 4), 5), LEFT_OTHER_COLUMN_STATS) .mapOutputRowCount(rowCount -> rowCount / UNKNOWN_FILTER_COEFFICIENT); PlanNodeStatsEstimate actual = JOIN_STATS_RULE.calculateJoinComplementStats( Optional.empty(), - ImmutableList.of(new EquiJoinClause(new Symbol(LEFT_JOIN_COLUMN), new Symbol(RIGHT_JOIN_COLUMN)), new EquiJoinClause(new Symbol(LEFT_OTHER_COLUMN), new Symbol(RIGHT_OTHER_COLUMN))), + ImmutableList.of( + new EquiJoinClause(LEFT_JOIN_COLUMN, RIGHT_JOIN_COLUMN), + new EquiJoinClause(LEFT_OTHER_COLUMN, RIGHT_OTHER_COLUMN)), LEFT_STATS, RIGHT_STATS, TYPES); @@ -264,10 +268,10 @@ public void testStatsForLeftAndRightJoin() PlanNodeStatsEstimate leftJoinStats = planNodeStats( totalRowCount, - symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, joinComplementColumnNulls * joinComplementRowCount / totalRowCount, LEFT_JOIN_COLUMN_NDV), + variableStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, joinComplementColumnNulls * joinComplementRowCount / totalRowCount, LEFT_JOIN_COLUMN_NDV), LEFT_OTHER_COLUMN_STATS, - symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, joinComplementRowCount / totalRowCount, RIGHT_JOIN_COLUMN_NDV), - symbolStatistics(RIGHT_OTHER_COLUMN, 24, 24, (0.24 * innerJoinRowCount + joinComplementRowCount) / totalRowCount, 1)); + variableStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, joinComplementRowCount / totalRowCount, RIGHT_JOIN_COLUMN_NDV), + variableStatistics(RIGHT_OTHER_COLUMN, 24, 24, (0.24 * innerJoinRowCount + joinComplementRowCount) / totalRowCount, 1)); assertJoinStats(LEFT, LEFT_STATS, RIGHT_STATS, leftJoinStats); assertJoinStats(RIGHT, RIGHT_JOIN_COLUMN, RIGHT_OTHER_COLUMN, LEFT_JOIN_COLUMN, LEFT_OTHER_COLUMN, RIGHT_STATS, LEFT_STATS, leftJoinStats); @@ -278,12 +282,12 @@ public void testLeftJoinMissingStats() { PlanNodeStatsEstimate leftStats = planNodeStats( 1, - new SymbolStatistics(LEFT_JOIN_COLUMN, SymbolStatsEstimate.unknown()), - new SymbolStatistics(LEFT_OTHER_COLUMN, SymbolStatsEstimate.unknown())); + new VariableStatistics(LEFT_JOIN_COLUMN, VariableStatsEstimate.unknown()), + new VariableStatistics(LEFT_OTHER_COLUMN, VariableStatsEstimate.unknown())); PlanNodeStatsEstimate rightStats = planNodeStats( 1, - new SymbolStatistics(RIGHT_JOIN_COLUMN, SymbolStatsEstimate.unknown()), - new SymbolStatistics(RIGHT_OTHER_COLUMN, SymbolStatsEstimate.unknown())); + new VariableStatistics(RIGHT_JOIN_COLUMN, VariableStatsEstimate.unknown()), + new VariableStatistics(RIGHT_OTHER_COLUMN, VariableStatsEstimate.unknown())); assertJoinStats(LEFT, leftStats, rightStats, PlanNodeStatsEstimate.unknown()); } @@ -299,10 +303,10 @@ public void testStatsForFullJoin() PlanNodeStatsEstimate leftJoinStats = planNodeStats( totalRowCount, - symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, (leftJoinComplementColumnNulls * leftJoinComplementRowCount + rightJoinComplementRowCount) / totalRowCount, LEFT_JOIN_COLUMN_NDV), - symbolStatistics(LEFT_OTHER_COLUMN, 42, 42, (0.42 * (innerJoinRowCount + leftJoinComplementRowCount) + rightJoinComplementRowCount) / totalRowCount, 1), - symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, (rightJoinComplementColumnNulls * rightJoinComplementRowCount + leftJoinComplementRowCount) / totalRowCount, RIGHT_JOIN_COLUMN_NDV), - symbolStatistics(RIGHT_OTHER_COLUMN, 24, 24, (0.24 * (innerJoinRowCount + rightJoinComplementRowCount) + leftJoinComplementRowCount) / totalRowCount, 1)); + variableStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, (leftJoinComplementColumnNulls * leftJoinComplementRowCount + rightJoinComplementRowCount) / totalRowCount, LEFT_JOIN_COLUMN_NDV), + variableStatistics(LEFT_OTHER_COLUMN, 42, 42, (0.42 * (innerJoinRowCount + leftJoinComplementRowCount) + rightJoinComplementRowCount) / totalRowCount, 1), + variableStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, (rightJoinComplementColumnNulls * rightJoinComplementRowCount + leftJoinComplementRowCount) / totalRowCount, RIGHT_JOIN_COLUMN_NDV), + variableStatistics(RIGHT_OTHER_COLUMN, 24, 24, (0.24 * (innerJoinRowCount + rightJoinComplementRowCount) + leftJoinComplementRowCount) / totalRowCount, 1)); assertJoinStats(FULL, LEFT_STATS, RIGHT_STATS, leftJoinStats); } @@ -312,11 +316,11 @@ public void testAddJoinComplementStats() { double statsToAddNdv = 5; PlanNodeStatsEstimate statsToAdd = planNodeStats(RIGHT_ROWS_COUNT, - symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 5.0, 0.2, statsToAddNdv)); + variableStatistics(LEFT_JOIN_COLUMN, 0.0, 5.0, 0.2, statsToAddNdv)); PlanNodeStatsEstimate addedStats = planNodeStats(TOTAL_ROWS_COUNT, - symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, (LEFT_ROWS_COUNT * LEFT_JOIN_COLUMN_NULLS + RIGHT_ROWS_COUNT * 0.2) / TOTAL_ROWS_COUNT, LEFT_JOIN_COLUMN_NDV), - symbolStatistics(LEFT_OTHER_COLUMN, 42, 42, (0.42 * LEFT_ROWS_COUNT + RIGHT_ROWS_COUNT) / TOTAL_ROWS_COUNT, 1)); + variableStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, (LEFT_ROWS_COUNT * LEFT_JOIN_COLUMN_NULLS + RIGHT_ROWS_COUNT * 0.2) / TOTAL_ROWS_COUNT, LEFT_JOIN_COLUMN_NDV), + variableStatistics(LEFT_OTHER_COLUMN, 42, 42, (0.42 * LEFT_ROWS_COUNT + RIGHT_ROWS_COUNT) / TOTAL_ROWS_COUNT, 1)); assertThat(JOIN_STATS_RULE.addJoinComplementStats( LEFT_STATS, @@ -330,37 +334,45 @@ private void assertJoinStats(JoinNode.Type joinType, PlanNodeStatsEstimate leftS assertJoinStats(joinType, LEFT_JOIN_COLUMN, LEFT_OTHER_COLUMN, RIGHT_JOIN_COLUMN, RIGHT_OTHER_COLUMN, leftStats, rightStats, resultStats); } - private void assertJoinStats(JoinNode.Type joinType, String leftJoinColumn, String leftOtherColumn, String rightJoinColumn, String rightOtherColumn, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate resultStats) + private void assertJoinStats( + JoinNode.Type joinType, + VariableReferenceExpression leftJoinColumn, + VariableReferenceExpression leftOtherColumn, + VariableReferenceExpression rightJoinColumn, + VariableReferenceExpression rightOtherColumn, + PlanNodeStatsEstimate leftStats, + PlanNodeStatsEstimate rightStats, + PlanNodeStatsEstimate resultStats) { tester().assertStatsFor(pb -> { - Symbol leftJoinColumnSymbol = pb.symbol(leftJoinColumn, BIGINT); - Symbol rightJoinColumnSymbol = pb.symbol(rightJoinColumn, DOUBLE); - Symbol leftOtherColumnSymbol = pb.symbol(leftOtherColumn, BIGINT); - Symbol rightOtherColumnSymbol = pb.symbol(rightOtherColumn, DOUBLE); + VariableReferenceExpression leftJoinColumnVariable = pb.variable(leftJoinColumn); + VariableReferenceExpression rightJoinColumnVariable = pb.variable(rightJoinColumn); + VariableReferenceExpression leftOtherColumnVariable = pb.variable(leftOtherColumn); + VariableReferenceExpression rightOtherColumnVariable = pb.variable(rightOtherColumn); return pb - .join(joinType, pb.values(leftJoinColumnSymbol, leftOtherColumnSymbol), - pb.values(rightJoinColumnSymbol, rightOtherColumnSymbol), - new EquiJoinClause(leftJoinColumnSymbol, rightJoinColumnSymbol)); + .join(joinType, pb.values(leftJoinColumnVariable, leftOtherColumnVariable), + pb.values(rightJoinColumnVariable, rightOtherColumnVariable), + new EquiJoinClause(leftJoinColumnVariable, rightJoinColumnVariable)); }).withSourceStats(0, leftStats) .withSourceStats(1, rightStats) .check(JOIN_STATS_RULE, stats -> stats.equalTo(resultStats)); } - private static PlanNodeStatsEstimate planNodeStats(double rowCount, SymbolStatistics... symbolStatistics) + private static PlanNodeStatsEstimate planNodeStats(double rowCount, VariableStatistics... variableStatistics) { PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder() .setOutputRowCount(rowCount); - for (SymbolStatistics symbolStatistic : symbolStatistics) { - builder.addSymbolStatistics(symbolStatistic.symbol, symbolStatistic.estimate); + for (VariableStatistics symbolStatistic : variableStatistics) { + builder.addVariableStatistics(symbolStatistic.variable, symbolStatistic.estimate); } return builder.build(); } - private static SymbolStatistics symbolStatistics(String symbolName, double low, double high, double nullsFraction, double ndv) + private static VariableStatistics variableStatistics(VariableReferenceExpression variable, double low, double high, double nullsFraction, double ndv) { - return new SymbolStatistics( - new Symbol(symbolName), - SymbolStatsEstimate.builder() + return new VariableStatistics( + variable, + VariableStatsEstimate.builder() .setLowValue(low) .setHighValue(high) .setNullsFraction(nullsFraction) @@ -368,19 +380,19 @@ private static SymbolStatistics symbolStatistics(String symbolName, double low, .build()); } - private static class SymbolStatistics + private static class VariableStatistics { - final Symbol symbol; - final SymbolStatsEstimate estimate; + final VariableReferenceExpression variable; + final VariableStatsEstimate estimate; - SymbolStatistics(String symbolName, SymbolStatsEstimate estimate) + VariableStatistics(String variableName, VariableStatsEstimate estimate) { - this(new Symbol(symbolName), estimate); + this(new VariableReferenceExpression(variableName, BIGINT), estimate); } - SymbolStatistics(Symbol symbol, SymbolStatsEstimate estimate) + VariableStatistics(VariableReferenceExpression variable, VariableStatsEstimate estimate) { - this.symbol = symbol; + this.variable = variable; this.estimate = estimate; } } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestOutputNodeStats.java b/presto-main/src/test/java/com/facebook/presto/cost/TestOutputNodeStats.java index 21e91384a3e14..3ecb8fa8b989e 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestOutputNodeStats.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestOutputNodeStats.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.cost; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import org.testng.annotations.Test; import static com.facebook.presto.spi.type.BigintType.BIGINT; @@ -28,17 +28,17 @@ public void testStatsForOutputNode() { PlanNodeStatsEstimate stats = PlanNodeStatsEstimate.builder() .setOutputRowCount(100) - .addSymbolStatistics( - new Symbol("a"), - SymbolStatsEstimate.builder() + .addVariableStatistics( + new VariableReferenceExpression("a", BIGINT), + VariableStatsEstimate.builder() .setNullsFraction(0.3) .setLowValue(1) .setHighValue(30) .setDistinctValuesCount(20) .build()) - .addSymbolStatistics( - new Symbol("b"), - SymbolStatsEstimate.builder() + .addVariableStatistics( + new VariableReferenceExpression("b", DOUBLE), + VariableStatsEstimate.builder() .setNullsFraction(0.6) .setLowValue(13.5) .setHighValue(POSITIVE_INFINITY) @@ -48,8 +48,8 @@ public void testStatsForOutputNode() tester().assertStatsFor(pb -> pb .output(outputBuilder -> { - Symbol a = pb.symbol("a", BIGINT); - Symbol b = pb.symbol("b", DOUBLE); + VariableReferenceExpression a = pb.variable(pb.symbol("a", BIGINT)); + VariableReferenceExpression b = pb.variable(pb.symbol("b", DOUBLE)); outputBuilder .source(pb.values(a, b)) .column(a, "a1") diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestPlanNodeStatsEstimateMath.java b/presto-main/src/test/java/com/facebook/presto/cost/TestPlanNodeStatsEstimateMath.java index 2e2ae96f144dc..1ebdd2a0e662a 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestPlanNodeStatsEstimateMath.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestPlanNodeStatsEstimateMath.java @@ -13,13 +13,14 @@ */ package com.facebook.presto.cost; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import org.testng.annotations.Test; import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.addStatsAndMaxDistinctValues; import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues; import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.capStats; import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.subtractSubsetStats; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.testing.assertions.Assert.assertEquals; import static java.lang.Double.NEGATIVE_INFINITY; import static java.lang.Double.NaN; @@ -27,7 +28,7 @@ public class TestPlanNodeStatsEstimateMath { - private static final Symbol SYMBOL = new Symbol("symbol"); + private static final VariableReferenceExpression VARIABLE = new VariableReferenceExpression("variable", BIGINT); private static final StatisticRange NON_EMPTY_RANGE = openRange(1); @Test @@ -64,7 +65,7 @@ public void testAddNullsFraction() private static void assertAddNullsFraction(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) { - assertEquals(addStatsAndSumDistinctValues(first, second).getSymbolStatistics(SYMBOL).getNullsFraction(), expected); + assertEquals(addStatsAndSumDistinctValues(first, second).getVariableStatistics(VARIABLE).getNullsFraction(), expected); } @Test @@ -90,7 +91,7 @@ public void testAddAverageRowSize() private static void assertAddAverageRowSize(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) { - assertEquals(addStatsAndSumDistinctValues(first, second).getSymbolStatistics(SYMBOL).getAverageRowSize(), expected); + assertEquals(addStatsAndSumDistinctValues(first, second).getVariableStatistics(VARIABLE).getAverageRowSize(), expected); } @Test @@ -111,7 +112,7 @@ public void testSumNumberOfDistinctValues() private static void assertSumNumberOfDistinctValues(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) { - assertEquals(addStatsAndSumDistinctValues(first, second).getSymbolStatistics(SYMBOL).getDistinctValuesCount(), expected); + assertEquals(addStatsAndSumDistinctValues(first, second).getVariableStatistics(VARIABLE).getDistinctValuesCount(), expected); } @Test @@ -132,7 +133,7 @@ public void testMaxNumberOfDistinctValues() private static void assertMaxNumberOfDistinctValues(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) { - assertEquals(addStatsAndMaxDistinctValues(first, second).getSymbolStatistics(SYMBOL).getDistinctValuesCount(), expected); + assertEquals(addStatsAndMaxDistinctValues(first, second).getVariableStatistics(VARIABLE).getDistinctValuesCount(), expected); } @Test @@ -153,7 +154,7 @@ public void testAddRange() private static void assertAddRange(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expectedLow, double expectedHigh) { - SymbolStatsEstimate statistics = addStatsAndMaxDistinctValues(first, second).getSymbolStatistics(SYMBOL); + VariableStatsEstimate statistics = addStatsAndMaxDistinctValues(first, second).getVariableStatistics(VARIABLE); assertEquals(statistics.getLowValue(), expectedLow); assertEquals(statistics.getHighValue(), expectedHigh); } @@ -191,7 +192,7 @@ public void testSubtractNullsFraction() private static void assertSubtractNullsFraction(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) { - assertEquals(subtractSubsetStats(first, second).getSymbolStatistics(SYMBOL).getNullsFraction(), expected); + assertEquals(subtractSubsetStats(first, second).getVariableStatistics(VARIABLE).getNullsFraction(), expected); } @Test @@ -215,7 +216,7 @@ public void testSubtractNumberOfDistinctValues() private static void assertSubtractNumberOfDistinctValues(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) { - assertEquals(subtractSubsetStats(first, second).getSymbolStatistics(SYMBOL).getDistinctValuesCount(), expected); + assertEquals(subtractSubsetStats(first, second).getVariableStatistics(VARIABLE).getDistinctValuesCount(), expected); } @Test @@ -234,7 +235,7 @@ private static void assertSubtractRange(double supersetLow, double supersetHigh, { PlanNodeStatsEstimate first = statistics(30, NaN, NaN, new StatisticRange(supersetLow, supersetHigh, 10)); PlanNodeStatsEstimate second = statistics(20, NaN, NaN, new StatisticRange(subsetLow, subsetHigh, 5)); - SymbolStatsEstimate statistics = subtractSubsetStats(first, second).getSymbolStatistics(SYMBOL); + VariableStatsEstimate statistics = subtractSubsetStats(first, second).getVariableStatistics(VARIABLE); assertEquals(statistics.getLowValue(), expectedLow); assertEquals(statistics.getHighValue(), expectedHigh); } @@ -272,7 +273,7 @@ public void testCapAverageRowSize() private static void assertCapAverageRowSize(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap, double expected) { - assertEquals(capStats(stats, cap).getSymbolStatistics(SYMBOL).getAverageRowSize(), expected); + assertEquals(capStats(stats, cap).getVariableStatistics(VARIABLE).getAverageRowSize(), expected); } @Test @@ -292,7 +293,7 @@ public void testCapNumberOfDistinctValues() private static void assertCapNumberOfDistinctValues(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap, double expected) { - assertEquals(capStats(stats, cap).getSymbolStatistics(SYMBOL).getDistinctValuesCount(), expected); + assertEquals(capStats(stats, cap).getVariableStatistics(VARIABLE).getDistinctValuesCount(), expected); } @Test @@ -313,7 +314,7 @@ public void testCapRange() private static void assertCapRange(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap, double expectedLow, double expectedHigh) { - SymbolStatsEstimate symbolStats = capStats(stats, cap).getSymbolStatistics(SYMBOL); + VariableStatsEstimate symbolStats = capStats(stats, cap).getVariableStatistics(VARIABLE); assertEquals(symbolStats.getLowValue(), expectedLow); assertEquals(symbolStats.getHighValue(), expectedHigh); } @@ -337,14 +338,14 @@ public void testCapNullsFraction() private static void assertCapNullsFraction(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap, double expected) { - assertEquals(capStats(stats, cap).getSymbolStatistics(SYMBOL).getNullsFraction(), expected); + assertEquals(capStats(stats, cap).getVariableStatistics(VARIABLE).getNullsFraction(), expected); } private static PlanNodeStatsEstimate statistics(double rowCount, double nullsFraction, double averageRowSize, StatisticRange range) { return PlanNodeStatsEstimate.builder() .setOutputRowCount(rowCount) - .addSymbolStatistics(SYMBOL, SymbolStatsEstimate.builder() + .addVariableStatistics(VARIABLE, VariableStatsEstimate.builder() .setNullsFraction(nullsFraction) .setAverageRowSize(averageRowSize) .setStatisticsRange(range) diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestRowNumberStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestRowNumberStatsRule.java index 8f00684a5ce2c..705f31c7d54f5 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestRowNumberStatsRule.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestRowNumberStatsRule.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.cost; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; @@ -24,11 +24,11 @@ public class TestRowNumberStatsRule extends BaseStatsCalculatorTest { - private SymbolStatsEstimate xStats = SymbolStatsEstimate.builder() + private VariableStatsEstimate xStats = VariableStatsEstimate.builder() .setDistinctValuesCount(5.0) .setNullsFraction(0) .build(); - private SymbolStatsEstimate yStats = SymbolStatsEstimate.builder() + private VariableStatsEstimate yStats = VariableStatsEstimate.builder() .setDistinctValuesCount(5.0) .setNullsFraction(0.5) .build(); @@ -39,20 +39,21 @@ public void testSingleGroupingKey() // grouping on a key with 0 nulls fraction without max rows per partition limit tester().assertStatsFor(pb -> pb .rowNumber( - ImmutableList.of(pb.symbol("x", BIGINT)), + ImmutableList.of(pb.variable("x", BIGINT)), Optional.empty(), pb.symbol("z", BIGINT), - pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)))) + pb.variable("z", BIGINT), + pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT)))) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(10) - .addSymbolStatistics(new Symbol("x"), xStats) - .addSymbolStatistics(new Symbol("y"), yStats) + .addVariableStatistics(new VariableReferenceExpression("x", BIGINT), xStats) + .addVariableStatistics(new VariableReferenceExpression("y", BIGINT), yStats) .build()) .check(check -> check .outputRowsCount(10) - .symbolStats("x", assertion -> assertion.isEqualTo(xStats)) - .symbolStats("y", assertion -> assertion.isEqualTo(yStats)) - .symbolStats("z", assertion -> assertion + .variableStats(new VariableReferenceExpression("x", BIGINT), assertion -> assertion.isEqualTo(xStats)) + .variableStats(new VariableReferenceExpression("y", BIGINT), assertion -> assertion.isEqualTo(yStats)) + .variableStats(new VariableReferenceExpression("z", BIGINT), assertion -> assertion .lowValue(1) .distinctValuesCount(2) .nullsFraction(0) @@ -61,18 +62,19 @@ public void testSingleGroupingKey() // grouping on a key with 0 nulls fraction with max rows per partition limit tester().assertStatsFor(pb -> pb .rowNumber( - ImmutableList.of(pb.symbol("x", BIGINT)), + ImmutableList.of(pb.variable(pb.symbol("x", BIGINT))), Optional.of(1), pb.symbol("z", BIGINT), - pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)))) + pb.variable("z", BIGINT), + pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT)))) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(10) - .addSymbolStatistics(new Symbol("x"), xStats) - .addSymbolStatistics(new Symbol("y"), yStats) + .addVariableStatistics(new VariableReferenceExpression("x", BIGINT), xStats) + .addVariableStatistics(new VariableReferenceExpression("y", BIGINT), yStats) .build()) .check(check -> check .outputRowsCount(5) - .symbolStats("z", assertion -> assertion + .variableStats(new VariableReferenceExpression("z", BIGINT), assertion -> assertion .lowValue(1) .distinctValuesCount(1) .nullsFraction(0) @@ -81,18 +83,19 @@ public void testSingleGroupingKey() // grouping on a key with non zero nulls fraction tester().assertStatsFor(pb -> pb .rowNumber( - ImmutableList.of(pb.symbol("y", BIGINT)), + ImmutableList.of(pb.variable(pb.symbol("y", BIGINT))), Optional.empty(), pb.symbol("z", BIGINT), - pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)))) + pb.variable("z", BIGINT), + pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT)))) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(60) - .addSymbolStatistics(new Symbol("x"), xStats) - .addSymbolStatistics(new Symbol("y"), yStats) + .addVariableStatistics(new VariableReferenceExpression("x", BIGINT), xStats) + .addVariableStatistics(new VariableReferenceExpression("y", BIGINT), yStats) .build()) .check(check -> check .outputRowsCount(60) - .symbolStats("z", assertion -> assertion + .variableStats(new VariableReferenceExpression("z", BIGINT), assertion -> assertion .lowValue(1) .distinctValuesCount(10) .nullsFraction(0) @@ -101,13 +104,14 @@ public void testSingleGroupingKey() // unknown input row count tester().assertStatsFor(pb -> pb .rowNumber( - ImmutableList.of(pb.symbol("x", BIGINT)), + ImmutableList.of(pb.variable("x", BIGINT)), Optional.of(1), pb.symbol("z", BIGINT), - pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)))) + pb.variable("z", BIGINT), + pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT)))) .withSourceStats(0, PlanNodeStatsEstimate.builder() - .addSymbolStatistics(new Symbol("x"), xStats) - .addSymbolStatistics(new Symbol("y"), yStats) + .addVariableStatistics(new VariableReferenceExpression("x", BIGINT), xStats) + .addVariableStatistics(new VariableReferenceExpression("y", BIGINT), yStats) .build()) .check(PlanNodeStatsAssertion::outputRowsCountUnknown); } @@ -118,18 +122,19 @@ public void testMultipleGroupingKeys() // grouping on multiple keys with the number of estimated groups less than the row count tester().assertStatsFor(pb -> pb .rowNumber( - ImmutableList.of(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)), + ImmutableList.of(pb.variable("x", BIGINT), pb.variable("y", BIGINT)), Optional.empty(), pb.symbol("z", BIGINT), - pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)))) + pb.variable("z", BIGINT), + pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT)))) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(60) - .addSymbolStatistics(new Symbol("x"), xStats) - .addSymbolStatistics(new Symbol("y"), yStats) + .addVariableStatistics(new VariableReferenceExpression("x", BIGINT), xStats) + .addVariableStatistics(new VariableReferenceExpression("y", BIGINT), yStats) .build()) .check(check -> check .outputRowsCount(60) - .symbolStats("z", assertion -> assertion + .variableStats(new VariableReferenceExpression("z", BIGINT), assertion -> assertion .lowValue(1) .distinctValuesCount(2) .nullsFraction(0) @@ -138,18 +143,19 @@ public void testMultipleGroupingKeys() // grouping on multiple keys with the number of estimated groups greater than the row count tester().assertStatsFor(pb -> pb .rowNumber( - ImmutableList.of(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)), + ImmutableList.of(pb.variable("x", BIGINT), pb.variable("y", BIGINT)), Optional.empty(), pb.symbol("z", BIGINT), - pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)))) + pb.variable("z", BIGINT), + pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT)))) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(20) - .addSymbolStatistics(new Symbol("x"), xStats) - .addSymbolStatistics(new Symbol("y"), yStats) + .addVariableStatistics(new VariableReferenceExpression("x", BIGINT), xStats) + .addVariableStatistics(new VariableReferenceExpression("y", BIGINT), yStats) .build()) .check(check -> check .outputRowsCount(20) - .symbolStats("z", assertion -> assertion + .variableStats(new VariableReferenceExpression("z", BIGINT), assertion -> assertion .lowValue(1) .distinctValuesCount(1) .nullsFraction(0) @@ -158,14 +164,15 @@ public void testMultipleGroupingKeys() // grouping on multiple keys with stats for one of the keys are unknown tester().assertStatsFor(pb -> pb .rowNumber( - ImmutableList.of(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)), + ImmutableList.of(pb.variable("x", BIGINT), pb.variable("y", BIGINT)), Optional.empty(), pb.symbol("z", BIGINT), - pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)))) + pb.variable("z", BIGINT), + pb.values(pb.variable("x", BIGINT), pb.variable("y", BIGINT)))) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(20) - .addSymbolStatistics(new Symbol("x"), xStats) - .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.unknown()) + .addVariableStatistics(new VariableReferenceExpression("x", BIGINT), xStats) + .addVariableStatistics(new VariableReferenceExpression("y", BIGINT), VariableStatsEstimate.unknown()) .build()) .check(PlanNodeStatsAssertion::outputRowsCountUnknown); } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsCalculator.java index b2e086d605ca8..669e71e5a046d 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsCalculator.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.parser.SqlParser; @@ -169,7 +170,7 @@ public void testVarbinaryConstant() @Test public void testSymbolReference() { - SymbolStatsEstimate xStats = SymbolStatsEstimate.builder() + VariableStatsEstimate xStats = VariableStatsEstimate.builder() .setLowValue(-1) .setHighValue(10) .setDistinctValuesCount(4) @@ -177,18 +178,18 @@ public void testSymbolReference() .setAverageRowSize(2.0) .build(); PlanNodeStatsEstimate inputStatistics = PlanNodeStatsEstimate.builder() - .addSymbolStatistics(new Symbol("x"), xStats) + .addVariableStatistics(new VariableReferenceExpression("x", BIGINT), xStats) .build(); assertCalculate(expression("x"), inputStatistics).isEqualTo(xStats); - assertCalculate(expression("y"), inputStatistics).isEqualTo(SymbolStatsEstimate.unknown()); + assertCalculate(expression("y"), inputStatistics).isEqualTo(VariableStatsEstimate.unknown()); } @Test public void testCastDoubleToBigint() { PlanNodeStatsEstimate inputStatistics = PlanNodeStatsEstimate.builder() - .addSymbolStatistics(new Symbol("a"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("a", BIGINT), VariableStatsEstimate.builder() .setNullsFraction(0.3) .setLowValue(1.6) .setHighValue(17.3) @@ -209,7 +210,7 @@ public void testCastDoubleToBigint() public void testCastDoubleToShortRange() { PlanNodeStatsEstimate inputStatistics = PlanNodeStatsEstimate.builder() - .addSymbolStatistics(new Symbol("a"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("a", BIGINT), VariableStatsEstimate.builder() .setNullsFraction(0.3) .setLowValue(1.6) .setHighValue(3.3) @@ -230,7 +231,7 @@ public void testCastDoubleToShortRange() public void testCastDoubleToShortRangeUnknownDistinctValuesCount() { PlanNodeStatsEstimate inputStatistics = PlanNodeStatsEstimate.builder() - .addSymbolStatistics(new Symbol("a"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("a", BIGINT), VariableStatsEstimate.builder() .setNullsFraction(0.3) .setLowValue(1.6) .setHighValue(3.3) @@ -250,7 +251,7 @@ public void testCastDoubleToShortRangeUnknownDistinctValuesCount() public void testCastBigintToDouble() { PlanNodeStatsEstimate inputStatistics = PlanNodeStatsEstimate.builder() - .addSymbolStatistics(new Symbol("a"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("a", DOUBLE), VariableStatsEstimate.builder() .setNullsFraction(0.3) .setLowValue(2.0) .setHighValue(10.0) @@ -278,38 +279,38 @@ public void testCastUnknown() .dataSizeUnknown(); } - private SymbolStatsAssertion assertCalculate(Expression scalarExpression) + private VariableStatsAssertion assertCalculate(Expression scalarExpression) { return assertCalculate(scalarExpression, PlanNodeStatsEstimate.unknown()); } - private SymbolStatsAssertion assertCalculate(Expression scalarExpression, PlanNodeStatsEstimate inputStatistics) + private VariableStatsAssertion assertCalculate(Expression scalarExpression, PlanNodeStatsEstimate inputStatistics) { return assertCalculate(scalarExpression, inputStatistics, TypeProvider.copyOf(DEFAULT_SYMBOL_TYPES)); } - private SymbolStatsAssertion assertCalculate(Expression scalarExpression, PlanNodeStatsEstimate inputStatistics, TypeProvider types) + private VariableStatsAssertion assertCalculate(Expression scalarExpression, PlanNodeStatsEstimate inputStatistics, TypeProvider types) { // assert both visitors yield the same result RowExpression scalarRowExpression = translator.translateAndOptimize(scalarExpression, types); - SymbolStatsEstimate expressionSymbolStatsEstimate = calculator.calculate(scalarExpression, inputStatistics, session, types); - SymbolStatsEstimate rowExpressionSymbolStatsEstimate = calculator.calculate(scalarRowExpression, inputStatistics, session); - assertEquals(expressionSymbolStatsEstimate, rowExpressionSymbolStatsEstimate); - return SymbolStatsAssertion.assertThat(expressionSymbolStatsEstimate); + VariableStatsEstimate expressionVariableStatsEstimate = calculator.calculate(scalarExpression, inputStatistics, session, types); + VariableStatsEstimate rowExpressionVariableStatsEstimate = calculator.calculate(scalarRowExpression, inputStatistics, session); + assertEquals(expressionVariableStatsEstimate, rowExpressionVariableStatsEstimate); + return VariableStatsAssertion.assertThat(expressionVariableStatsEstimate); } @Test public void testNonDivideArithmeticBinaryExpression() { PlanNodeStatsEstimate relationStats = PlanNodeStatsEstimate.builder() - .addSymbolStatistics(new Symbol("x"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("x", BIGINT), VariableStatsEstimate.builder() .setLowValue(-1) .setHighValue(10) .setDistinctValuesCount(4) .setNullsFraction(0.1) .setAverageRowSize(2.0) .build()) - .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("y", BIGINT), VariableStatsEstimate.builder() .setLowValue(-2) .setHighValue(5) .setDistinctValuesCount(3) @@ -345,7 +346,7 @@ public void testNonDivideArithmeticBinaryExpression() public void tesArithmeticUnaryExpression() { PlanNodeStatsEstimate relationStats = PlanNodeStatsEstimate.builder() - .addSymbolStatistics(new Symbol("x"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("x", BIGINT), VariableStatsEstimate.builder() .setLowValue(-1) .setHighValue(10) .setDistinctValuesCount(4) @@ -373,16 +374,16 @@ public void tesArithmeticUnaryExpression() @Test public void testArithmeticBinaryWithAllNullsSymbol() { - SymbolStatsEstimate allNullStats = SymbolStatsEstimate.zero(); + VariableStatsEstimate allNullStats = VariableStatsEstimate.zero(); PlanNodeStatsEstimate relationStats = PlanNodeStatsEstimate.builder() - .addSymbolStatistics(new Symbol("x"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("x", BIGINT), VariableStatsEstimate.builder() .setLowValue(-1) .setHighValue(10) .setDistinctValuesCount(4) .setNullsFraction(0.1) .setAverageRowSize(0) .build()) - .addSymbolStatistics(new Symbol("all_null"), allNullStats) + .addVariableStatistics(new VariableReferenceExpression("all_null", BIGINT), allNullStats) .setOutputRowCount(10) .build(); @@ -480,11 +481,11 @@ public void testModulusArithmeticBinaryExpression() private PlanNodeStatsEstimate xyStats(double lowX, double highX, double lowY, double highY) { return PlanNodeStatsEstimate.builder() - .addSymbolStatistics(new Symbol("x"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("x", BIGINT), VariableStatsEstimate.builder() .setLowValue(lowX) .setHighValue(highX) .build()) - .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("y", BIGINT), VariableStatsEstimate.builder() .setLowValue(lowY) .setHighValue(highY) .build()) @@ -495,14 +496,14 @@ private PlanNodeStatsEstimate xyStats(double lowX, double highX, double lowY, do public void testCoalesceExpression() { PlanNodeStatsEstimate relationStats = PlanNodeStatsEstimate.builder() - .addSymbolStatistics(new Symbol("x"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("x", BIGINT), VariableStatsEstimate.builder() .setLowValue(-1) .setHighValue(10) .setDistinctValuesCount(4) .setNullsFraction(0.1) .setAverageRowSize(2.0) .build()) - .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression("y", BIGINT), VariableStatsEstimate.builder() .setLowValue(-2) .setHighValue(5) .setDistinctValuesCount(3) diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestSemiJoinStatsCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestSemiJoinStatsCalculator.java index d077654c7539f..a2c7414ff9716 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestSemiJoinStatsCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestSemiJoinStatsCalculator.java @@ -14,13 +14,14 @@ package com.facebook.presto.cost; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; import static com.facebook.presto.cost.PlanNodeStatsAssertion.assertThat; import static com.facebook.presto.cost.SemiJoinStatsCalculator.computeAntiJoin; import static com.facebook.presto.cost.SemiJoinStatsCalculator.computeSemiJoin; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static java.lang.Double.NEGATIVE_INFINITY; import static java.lang.Double.NaN; import static java.lang.Double.POSITIVE_INFINITY; @@ -28,113 +29,113 @@ public class TestSemiJoinStatsCalculator { private PlanNodeStatsEstimate inputStatistics; - private SymbolStatsEstimate uStats; - private SymbolStatsEstimate wStats; - private SymbolStatsEstimate xStats; - private SymbolStatsEstimate yStats; - private SymbolStatsEstimate zStats; - private SymbolStatsEstimate leftOpenStats; - private SymbolStatsEstimate rightOpenStats; - private SymbolStatsEstimate unknownRangeStats; - private SymbolStatsEstimate emptyRangeStats; - private SymbolStatsEstimate fractionalNdvStats; + private VariableStatsEstimate uStats; + private VariableStatsEstimate wStats; + private VariableStatsEstimate xStats; + private VariableStatsEstimate yStats; + private VariableStatsEstimate zStats; + private VariableStatsEstimate leftOpenStats; + private VariableStatsEstimate rightOpenStats; + private VariableStatsEstimate unknownRangeStats; + private VariableStatsEstimate emptyRangeStats; + private VariableStatsEstimate fractionalNdvStats; - private Symbol u = new Symbol("u"); - private Symbol w = new Symbol("w"); - private Symbol x = new Symbol("x"); - private Symbol y = new Symbol("y"); - private Symbol z = new Symbol("z"); - private Symbol leftOpen = new Symbol("leftOpen"); - private Symbol rightOpen = new Symbol("rightOpen"); - private Symbol unknownRange = new Symbol("unknownRange"); - private Symbol emptyRange = new Symbol("emptyRange"); - private Symbol unknown = new Symbol("unknown"); - private Symbol fractionalNdv = new Symbol("fractionalNdv"); + private VariableReferenceExpression u = new VariableReferenceExpression("u", BIGINT); + private VariableReferenceExpression w = new VariableReferenceExpression("w", BIGINT); + private VariableReferenceExpression x = new VariableReferenceExpression("x", BIGINT); + private VariableReferenceExpression y = new VariableReferenceExpression("y", BIGINT); + private VariableReferenceExpression z = new VariableReferenceExpression("z", BIGINT); + private VariableReferenceExpression leftOpen = new VariableReferenceExpression("leftOpen", BIGINT); + private VariableReferenceExpression rightOpen = new VariableReferenceExpression("rightOpen", BIGINT); + private VariableReferenceExpression unknownRange = new VariableReferenceExpression("unknownRange", BIGINT); + private VariableReferenceExpression emptyRange = new VariableReferenceExpression("emptyRange", BIGINT); + private VariableReferenceExpression unknown = new VariableReferenceExpression("unknown", BIGINT); + private VariableReferenceExpression fractionalNdv = new VariableReferenceExpression("fractionalNdv", BIGINT); @BeforeMethod public void setUp() throws Exception { - uStats = SymbolStatsEstimate.builder() + uStats = VariableStatsEstimate.builder() .setAverageRowSize(8.0) .setDistinctValuesCount(300) .setLowValue(0) .setHighValue(20) .setNullsFraction(0.1) .build(); - wStats = SymbolStatsEstimate.builder() + wStats = VariableStatsEstimate.builder() .setAverageRowSize(8.0) .setDistinctValuesCount(30) .setLowValue(0) .setHighValue(20) .setNullsFraction(0.1) .build(); - xStats = SymbolStatsEstimate.builder() + xStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(40.0) .setLowValue(-10.0) .setHighValue(10.0) .setNullsFraction(0.25) .build(); - yStats = SymbolStatsEstimate.builder() + yStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(20.0) .setLowValue(0.0) .setHighValue(5.0) .setNullsFraction(0.5) .build(); - zStats = SymbolStatsEstimate.builder() + zStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(5.0) .setLowValue(-100.0) .setHighValue(100.0) .setNullsFraction(0.1) .build(); - leftOpenStats = SymbolStatsEstimate.builder() + leftOpenStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(50.0) .setLowValue(NEGATIVE_INFINITY) .setHighValue(15.0) .setNullsFraction(0.1) .build(); - rightOpenStats = SymbolStatsEstimate.builder() + rightOpenStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(50.0) .setLowValue(-15.0) .setHighValue(POSITIVE_INFINITY) .setNullsFraction(0.1) .build(); - unknownRangeStats = SymbolStatsEstimate.builder() + unknownRangeStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(50.0) .setLowValue(NEGATIVE_INFINITY) .setHighValue(POSITIVE_INFINITY) .setNullsFraction(0.1) .build(); - emptyRangeStats = SymbolStatsEstimate.builder() + emptyRangeStats = VariableStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(0.0) .setLowValue(NaN) .setHighValue(NaN) .setNullsFraction(NaN) .build(); - fractionalNdvStats = SymbolStatsEstimate.builder() + fractionalNdvStats = VariableStatsEstimate.builder() .setAverageRowSize(NaN) .setDistinctValuesCount(0.1) .setNullsFraction(0) .build(); inputStatistics = PlanNodeStatsEstimate.builder() - .addSymbolStatistics(u, uStats) - .addSymbolStatistics(w, wStats) - .addSymbolStatistics(x, xStats) - .addSymbolStatistics(y, yStats) - .addSymbolStatistics(z, zStats) - .addSymbolStatistics(leftOpen, leftOpenStats) - .addSymbolStatistics(rightOpen, rightOpenStats) - .addSymbolStatistics(unknownRange, unknownRangeStats) - .addSymbolStatistics(emptyRange, emptyRangeStats) - .addSymbolStatistics(unknown, SymbolStatsEstimate.unknown()) - .addSymbolStatistics(fractionalNdv, fractionalNdvStats) + .addVariableStatistics(u, uStats) + .addVariableStatistics(w, wStats) + .addVariableStatistics(x, xStats) + .addVariableStatistics(y, yStats) + .addVariableStatistics(z, zStats) + .addVariableStatistics(leftOpen, leftOpenStats) + .addVariableStatistics(rightOpen, rightOpenStats) + .addVariableStatistics(unknownRange, unknownRangeStats) + .addVariableStatistics(emptyRange, emptyRangeStats) + .addVariableStatistics(unknown, VariableStatsEstimate.unknown()) + .addVariableStatistics(fractionalNdv, fractionalNdvStats) .setOutputRowCount(1000.0) .build(); } @@ -144,45 +145,45 @@ public void testSemiJoin() { // overlapping ranges assertThat(computeSemiJoin(inputStatistics, inputStatistics, x, w)) - .symbolStats(x, stats -> stats + .variableStats(x, stats -> stats .lowValue(xStats.getLowValue()) .highValue(xStats.getHighValue()) .nullsFraction(0) .distinctValuesCount(wStats.getDistinctValuesCount())) - .symbolStats(w, stats -> stats.isEqualTo(wStats)) - .symbolStats(z, stats -> stats.isEqualTo(zStats)) + .variableStats(w, stats -> stats.isEqualTo(wStats)) + .variableStats(z, stats -> stats.isEqualTo(zStats)) .outputRowsCount(inputStatistics.getOutputRowCount() * xStats.getValuesFraction() * (wStats.getDistinctValuesCount() / xStats.getDistinctValuesCount())); // overlapping ranges, nothing filtered out assertThat(computeSemiJoin(inputStatistics, inputStatistics, x, u)) - .symbolStats(x, stats -> stats + .variableStats(x, stats -> stats .lowValue(xStats.getLowValue()) .highValue(xStats.getHighValue()) .nullsFraction(0) .distinctValuesCount(xStats.getDistinctValuesCount())) - .symbolStats(u, stats -> stats.isEqualTo(uStats)) - .symbolStats(z, stats -> stats.isEqualTo(zStats)) + .variableStats(u, stats -> stats.isEqualTo(uStats)) + .variableStats(z, stats -> stats.isEqualTo(zStats)) .outputRowsCount(inputStatistics.getOutputRowCount() * xStats.getValuesFraction()); // source stats are unknown assertThat(computeSemiJoin(inputStatistics, inputStatistics, unknown, u)) - .symbolStats(unknown, stats -> stats + .variableStats(unknown, stats -> stats .nullsFraction(0) .distinctValuesCountUnknown() .unknownRange()) - .symbolStats(u, stats -> stats.isEqualTo(uStats)) - .symbolStats(z, stats -> stats.isEqualTo(zStats)) + .variableStats(u, stats -> stats.isEqualTo(uStats)) + .variableStats(z, stats -> stats.isEqualTo(zStats)) .outputRowsCountUnknown(); // filtering stats are unknown assertThat(computeSemiJoin(inputStatistics, inputStatistics, x, unknown)) - .symbolStats(x, stats -> stats + .variableStats(x, stats -> stats .nullsFraction(0) .lowValue(xStats.getLowValue()) .highValue(xStats.getHighValue()) .distinctValuesCountUnknown()) - .symbolStatsUnknown(unknown) - .symbolStats(z, stats -> stats.isEqualTo(zStats)) + .variableStatsUnknown(unknown) + .variableStats(z, stats -> stats.isEqualTo(zStats)) .outputRowsCountUnknown(); // zero distinct values @@ -192,7 +193,7 @@ public void testSemiJoin() // fractional distinct values assertThat(computeSemiJoin(inputStatistics, inputStatistics, fractionalNdv, fractionalNdv)) .outputRowsCount(1000) - .symbolStats(fractionalNdv, stats -> stats + .variableStats(fractionalNdv, stats -> stats .nullsFraction(0) .distinctValuesCount(0.1)); } @@ -202,45 +203,45 @@ public void testAntiJoin() { // overlapping ranges assertThat(computeAntiJoin(inputStatistics, inputStatistics, u, x)) - .symbolStats(u, stats -> stats + .variableStats(u, stats -> stats .lowValue(uStats.getLowValue()) .highValue(uStats.getHighValue()) .nullsFraction(0) .distinctValuesCount(uStats.getDistinctValuesCount() - xStats.getDistinctValuesCount())) - .symbolStats(x, stats -> stats.isEqualTo(xStats)) - .symbolStats(z, stats -> stats.isEqualTo(zStats)) + .variableStats(x, stats -> stats.isEqualTo(xStats)) + .variableStats(z, stats -> stats.isEqualTo(zStats)) .outputRowsCount(inputStatistics.getOutputRowCount() * uStats.getValuesFraction() * (1 - xStats.getDistinctValuesCount() / uStats.getDistinctValuesCount())); // overlapping ranges, everything filtered out (but we leave 0.5 due to safety coeeficient) assertThat(computeAntiJoin(inputStatistics, inputStatistics, x, u)) - .symbolStats(x, stats -> stats + .variableStats(x, stats -> stats .lowValue(xStats.getLowValue()) .highValue(xStats.getHighValue()) .nullsFraction(0) .distinctValuesCount(xStats.getDistinctValuesCount() * 0.5)) - .symbolStats(u, stats -> stats.isEqualTo(uStats)) - .symbolStats(z, stats -> stats.isEqualTo(zStats)) + .variableStats(u, stats -> stats.isEqualTo(uStats)) + .variableStats(z, stats -> stats.isEqualTo(zStats)) .outputRowsCount(inputStatistics.getOutputRowCount() * xStats.getValuesFraction() * 0.5); // source stats are unknown assertThat(computeAntiJoin(inputStatistics, inputStatistics, unknown, u)) - .symbolStats(unknown, stats -> stats + .variableStats(unknown, stats -> stats .nullsFraction(0) .distinctValuesCountUnknown() .unknownRange()) - .symbolStats(u, stats -> stats.isEqualTo(uStats)) - .symbolStats(z, stats -> stats.isEqualTo(zStats)) + .variableStats(u, stats -> stats.isEqualTo(uStats)) + .variableStats(z, stats -> stats.isEqualTo(zStats)) .outputRowsCountUnknown(); // filtering stats are unknown assertThat(computeAntiJoin(inputStatistics, inputStatistics, x, unknown)) - .symbolStats(x, stats -> stats + .variableStats(x, stats -> stats .nullsFraction(0) .lowValue(xStats.getLowValue()) .highValue(xStats.getHighValue()) .distinctValuesCountUnknown()) - .symbolStatsUnknown(unknown) - .symbolStats(z, stats -> stats.isEqualTo(zStats)) + .variableStatsUnknown(unknown) + .variableStats(z, stats -> stats.isEqualTo(zStats)) .outputRowsCountUnknown(); // zero distinct values @@ -250,7 +251,7 @@ public void testAntiJoin() // fractional distinct values assertThat(computeAntiJoin(inputStatistics, inputStatistics, fractionalNdv, fractionalNdv)) .outputRowsCount(500) - .symbolStats(fractionalNdv, stats -> stats + .variableStats(fractionalNdv, stats -> stats .nullsFraction(0) .distinctValuesCount(0.05)); } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestSemiJoinStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestSemiJoinStatsRule.java index 022b28180b859..30791ef8e4482 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestSemiJoinStatsRule.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestSemiJoinStatsRule.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.cost; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import org.testng.annotations.Test; import java.util.Optional; @@ -27,7 +27,7 @@ public class TestSemiJoinStatsRule @Test public void testSemiJoinPropagatesSourceStats() { - SymbolStatsEstimate stats = SymbolStatsEstimate.builder() + VariableStatsEstimate stats = VariableStatsEstimate.builder() .setLowValue(1) .setHighValue(10) .setDistinctValuesCount(5) @@ -35,10 +35,10 @@ public void testSemiJoinPropagatesSourceStats() .build(); tester().assertStatsFor(pb -> { - Symbol a = pb.symbol("a", BIGINT); - Symbol b = pb.symbol("b", BIGINT); - Symbol c = pb.symbol("c", BIGINT); - Symbol semiJoinOutput = pb.symbol("sjo", BOOLEAN); + VariableReferenceExpression a = pb.variable("a", BIGINT); + VariableReferenceExpression b = pb.variable("b", BIGINT); + VariableReferenceExpression c = pb.variable("c", BIGINT); + VariableReferenceExpression semiJoinOutput = pb.variable("sjo", BOOLEAN); return pb .semiJoin(pb.values(a, b), pb.values(c), @@ -51,18 +51,18 @@ public void testSemiJoinPropagatesSourceStats() }) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(10) - .addSymbolStatistics(new Symbol("a"), stats) - .addSymbolStatistics(new Symbol("b"), stats) + .addVariableStatistics(new VariableReferenceExpression("a", BIGINT), stats) + .addVariableStatistics(new VariableReferenceExpression("b", BIGINT), stats) .build()) .withSourceStats(1, PlanNodeStatsEstimate.builder() .setOutputRowCount(20) - .addSymbolStatistics(new Symbol("c"), stats) + .addVariableStatistics(new VariableReferenceExpression("c", BIGINT), stats) .build()) .check(check -> check .outputRowsCount(10) - .symbolStats("a", assertion -> assertion.isEqualTo(stats)) - .symbolStats("b", assertion -> assertion.isEqualTo(stats)) - .symbolStatsUnknown("c") - .symbolStatsUnknown("sjo")); + .variableStats(new VariableReferenceExpression("a", BIGINT), assertion -> assertion.isEqualTo(stats)) + .variableStats(new VariableReferenceExpression("b", BIGINT), assertion -> assertion.isEqualTo(stats)) + .variableStatsUnknown("c") + .variableStatsUnknown("sjo")); } } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestSimpleFilterProjectSemiJoinStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestSimpleFilterProjectSemiJoinStatsRule.java index 1f85cbf8ff025..0a903b0f58bcb 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestSimpleFilterProjectSemiJoinStatsRule.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestSimpleFilterProjectSemiJoinStatsRule.java @@ -15,6 +15,7 @@ import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.Assignments; @@ -32,42 +33,42 @@ public class TestSimpleFilterProjectSemiJoinStatsRule extends BaseStatsCalculatorTest { - private SymbolStatsEstimate aStats = SymbolStatsEstimate.builder() + private VariableStatsEstimate aStats = VariableStatsEstimate.builder() .setLowValue(0) .setHighValue(10) .setDistinctValuesCount(10) .setNullsFraction(0.1) .build(); - private SymbolStatsEstimate bStats = SymbolStatsEstimate.builder() + private VariableStatsEstimate bStats = VariableStatsEstimate.builder() .setLowValue(0) .setHighValue(100) .setDistinctValuesCount(10) .setNullsFraction(0) .build(); - private SymbolStatsEstimate cStats = SymbolStatsEstimate.builder() + private VariableStatsEstimate cStats = VariableStatsEstimate.builder() .setLowValue(5) .setHighValue(30) .setDistinctValuesCount(2) .setNullsFraction(0.5) .build(); - private SymbolStatsEstimate expectedAInC = SymbolStatsEstimate.builder() + private VariableStatsEstimate expectedAInC = VariableStatsEstimate.builder() .setDistinctValuesCount(2) .setLowValue(0) .setHighValue(10) .setNullsFraction(0) .build(); - private SymbolStatsEstimate expectedANotInC = SymbolStatsEstimate.builder() + private VariableStatsEstimate expectedANotInC = VariableStatsEstimate.builder() .setDistinctValuesCount(1.6) .setLowValue(0) .setHighValue(8) .setNullsFraction(0) .build(); - private SymbolStatsEstimate expectedANotInCWithExtraFilter = SymbolStatsEstimate.builder() + private VariableStatsEstimate expectedANotInCWithExtraFilter = VariableStatsEstimate.builder() .setDistinctValuesCount(8) .setLowValue(0) .setHighValue(10) @@ -90,28 +91,28 @@ public void testFilterPositiveSemiJoin(boolean toRowExpression) getStatsCalculatorAssertion(new Symbol("sjo").toSymbolReference(), toRowExpression) .withSourceStats(LEFT_SOURCE_ID, PlanNodeStatsEstimate.builder() .setOutputRowCount(1000) - .addSymbolStatistics(new Symbol("a"), aStats) - .addSymbolStatistics(new Symbol("b"), bStats) + .addVariableStatistics(new VariableReferenceExpression("a", BIGINT), aStats) + .addVariableStatistics(new VariableReferenceExpression("b", BIGINT), bStats) .build()) .withSourceStats(RIGHT_SOURCE_ID, PlanNodeStatsEstimate.builder() .setOutputRowCount(2000) - .addSymbolStatistics(new Symbol("c"), cStats) + .addVariableStatistics(new VariableReferenceExpression("c", BIGINT), cStats) .build()) .check(check -> check.outputRowsCount(180) - .symbolStats("a", assertion -> assertion.isEqualTo(expectedAInC)) - .symbolStats("b", assertion -> assertion.isEqualTo(bStats)) - .symbolStatsUnknown("c") - .symbolStatsUnknown("sjo")); + .variableStats(new VariableReferenceExpression("a", BIGINT), assertion -> assertion.isEqualTo(expectedAInC)) + .variableStats(new VariableReferenceExpression("b", BIGINT), assertion -> assertion.isEqualTo(bStats)) + .variableStatsUnknown("c") + .variableStatsUnknown("sjo")); } @Test(dataProvider = "toRowExpression") public void testFilterPositiveNarrowingProjectSemiJoin(boolean toRowExpression) { tester().assertStatsFor(pb -> { - Symbol a = pb.symbol("a", BIGINT); - Symbol b = pb.symbol("b", BIGINT); - Symbol c = pb.symbol("c", BIGINT); - Symbol semiJoinOutput = pb.symbol("sjo", BOOLEAN); + VariableReferenceExpression a = pb.variable("a", BIGINT); + VariableReferenceExpression b = pb.variable("b", BIGINT); + VariableReferenceExpression c = pb.variable("c", BIGINT); + VariableReferenceExpression semiJoinOutput = pb.variable("sjo", BOOLEAN); PlanNode semiJoinNode = pb.semiJoin( pb.values(LEFT_SOURCE_ID, a, b), @@ -132,18 +133,18 @@ public void testFilterPositiveNarrowingProjectSemiJoin(boolean toRowExpression) }) .withSourceStats(LEFT_SOURCE_ID, PlanNodeStatsEstimate.builder() .setOutputRowCount(1000) - .addSymbolStatistics(new Symbol("a"), aStats) - .addSymbolStatistics(new Symbol("b"), bStats) + .addVariableStatistics(new VariableReferenceExpression("a", BIGINT), aStats) + .addVariableStatistics(new VariableReferenceExpression("b", BIGINT), bStats) .build()) .withSourceStats(RIGHT_SOURCE_ID, PlanNodeStatsEstimate.builder() .setOutputRowCount(2000) - .addSymbolStatistics(new Symbol("c"), cStats) + .addVariableStatistics(new VariableReferenceExpression("c", BIGINT), cStats) .build()) .check(check -> check.outputRowsCount(180) - .symbolStats("a", assertion -> assertion.isEqualTo(expectedAInC)) - .symbolStatsUnknown("b") - .symbolStatsUnknown("c") - .symbolStatsUnknown("sjo")); + .variableStats(new VariableReferenceExpression("a", BIGINT), assertion -> assertion.isEqualTo(expectedAInC)) + .variableStatsUnknown("b") + .variableStatsUnknown("c") + .variableStatsUnknown("sjo")); } @Test(dataProvider = "toRowExpression") @@ -152,18 +153,18 @@ public void testFilterPositivePlusExtraConjunctSemiJoin(boolean toRowExpression) getStatsCalculatorAssertion(expression("sjo AND a < 8"), toRowExpression) .withSourceStats(LEFT_SOURCE_ID, PlanNodeStatsEstimate.builder() .setOutputRowCount(1000) - .addSymbolStatistics(new Symbol("a"), aStats) - .addSymbolStatistics(new Symbol("b"), bStats) + .addVariableStatistics(new VariableReferenceExpression("a", BIGINT), aStats) + .addVariableStatistics(new VariableReferenceExpression("b", BIGINT), bStats) .build()) .withSourceStats(RIGHT_SOURCE_ID, PlanNodeStatsEstimate.builder() .setOutputRowCount(2000) - .addSymbolStatistics(new Symbol("c"), cStats) + .addVariableStatistics(new VariableReferenceExpression("c", BIGINT), cStats) .build()) .check(check -> check.outputRowsCount(144) - .symbolStats("a", assertion -> assertion.isEqualTo(expectedANotInC)) - .symbolStats("b", assertion -> assertion.isEqualTo(bStats)) - .symbolStatsUnknown("c") - .symbolStatsUnknown("sjo")); + .variableStats(new VariableReferenceExpression("a", BIGINT), assertion -> assertion.isEqualTo(expectedANotInC)) + .variableStats(new VariableReferenceExpression("b", BIGINT), assertion -> assertion.isEqualTo(bStats)) + .variableStatsUnknown("c") + .variableStatsUnknown("sjo")); } @Test(dataProvider = "toRowExpression") @@ -172,27 +173,27 @@ public void testFilterNegativeSemiJoin(boolean toRowExpression) getStatsCalculatorAssertion(expression("NOT sjo"), toRowExpression) .withSourceStats(LEFT_SOURCE_ID, PlanNodeStatsEstimate.builder() .setOutputRowCount(1000) - .addSymbolStatistics(new Symbol("a"), aStats) - .addSymbolStatistics(new Symbol("b"), bStats) + .addVariableStatistics(new VariableReferenceExpression("a", BIGINT), aStats) + .addVariableStatistics(new VariableReferenceExpression("b", BIGINT), bStats) .build()) .withSourceStats(RIGHT_SOURCE_ID, PlanNodeStatsEstimate.builder() .setOutputRowCount(2000) - .addSymbolStatistics(new Symbol("c"), cStats) + .addVariableStatistics(new VariableReferenceExpression("c", BIGINT), cStats) .build()) .check(check -> check.outputRowsCount(720) - .symbolStats("a", assertion -> assertion.isEqualTo(expectedANotInCWithExtraFilter)) - .symbolStats("b", assertion -> assertion.isEqualTo(bStats)) - .symbolStatsUnknown("c") - .symbolStatsUnknown("sjo")); + .variableStats(new VariableReferenceExpression("a", BIGINT), assertion -> assertion.isEqualTo(expectedANotInCWithExtraFilter)) + .variableStats(new VariableReferenceExpression("b", BIGINT), assertion -> assertion.isEqualTo(bStats)) + .variableStatsUnknown("c") + .variableStatsUnknown("sjo")); } private StatsCalculatorAssertion getStatsCalculatorAssertion(Expression expression, boolean toRowExpression) { return tester().assertStatsFor(pb -> { - Symbol a = pb.symbol("a", BIGINT); - Symbol b = pb.symbol("b", BIGINT); - Symbol c = pb.symbol("c", BIGINT); - Symbol semiJoinOutput = pb.symbol("sjo", BOOLEAN); + VariableReferenceExpression a = pb.variable("a", BIGINT); + VariableReferenceExpression b = pb.variable("b", BIGINT); + VariableReferenceExpression c = pb.variable("c", BIGINT); + VariableReferenceExpression semiJoinOutput = pb.variable("sjo", BOOLEAN); PlanNode semiJoinNode = pb.semiJoin( pb.values(LEFT_SOURCE_ID, a, b), diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestSortNodeStats.java b/presto-main/src/test/java/com/facebook/presto/cost/TestSortNodeStats.java index 7eff276a1b383..daf173e91f464 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestSortNodeStats.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestSortNodeStats.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.cost; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import org.testng.annotations.Test; import static com.facebook.presto.spi.type.BigintType.BIGINT; @@ -28,17 +28,17 @@ public void testStatsForSortNode() { PlanNodeStatsEstimate stats = PlanNodeStatsEstimate.builder() .setOutputRowCount(100) - .addSymbolStatistics( - new Symbol("a"), - SymbolStatsEstimate.builder() + .addVariableStatistics( + new VariableReferenceExpression("a", BIGINT), + VariableStatsEstimate.builder() .setNullsFraction(0.3) .setLowValue(1) .setHighValue(30) .setDistinctValuesCount(20) .build()) - .addSymbolStatistics( - new Symbol("b"), - SymbolStatsEstimate.builder() + .addVariableStatistics( + new VariableReferenceExpression("b", DOUBLE), + VariableStatsEstimate.builder() .setNullsFraction(0.6) .setLowValue(13.5) .setHighValue(POSITIVE_INFINITY) @@ -48,8 +48,8 @@ public void testStatsForSortNode() tester().assertStatsFor(pb -> pb .output(outputBuilder -> { - Symbol a = pb.symbol("a", BIGINT); - Symbol b = pb.symbol("b", DOUBLE); + VariableReferenceExpression a = pb.variable(pb.symbol("a", BIGINT)); + VariableReferenceExpression b = pb.variable(pb.symbol("b", DOUBLE)); outputBuilder .source(pb.values(a, b)) .column(a, "a1") diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestStatsNormalizer.java b/presto-main/src/test/java/com/facebook/presto/cost/TestStatsNormalizer.java index 21bed76225e3a..13b80d6babe9c 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestStatsNormalizer.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestStatsNormalizer.java @@ -16,15 +16,13 @@ import com.facebook.presto.block.BlockEncodingManager; import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.testing.TestingConnectorSession; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; import java.time.LocalDate; @@ -38,10 +36,8 @@ import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.facebook.presto.spi.type.SmallintType.SMALLINT; import static com.facebook.presto.spi.type.TinyintType.TINYINT; -import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.lang.Double.NaN; import static java.util.Collections.emptyList; -import static java.util.function.Function.identity; public class TestStatsNormalizer { @@ -54,51 +50,51 @@ public class TestStatsNormalizer @Test public void testNoCapping() { - Symbol a = new Symbol("a"); + VariableReferenceExpression a = new VariableReferenceExpression("a", BIGINT); PlanNodeStatsEstimate estimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(30) - .addSymbolStatistics(a, SymbolStatsEstimate.builder().setDistinctValuesCount(20).build()) + .addVariableStatistics(a, VariableStatsEstimate.builder().setDistinctValuesCount(20).build()) .build(); assertNormalized(estimate) - .symbolStats(a, symbolAssert -> symbolAssert.distinctValuesCount(20)); + .variableStats(a, variableAssert -> variableAssert.distinctValuesCount(20)); } @Test public void testDropNonOutputSymbols() { - Symbol a = new Symbol("a"); - Symbol b = new Symbol("b"); - Symbol c = new Symbol("c"); + VariableReferenceExpression a = new VariableReferenceExpression("a", BIGINT); + VariableReferenceExpression b = new VariableReferenceExpression("b", BIGINT); + VariableReferenceExpression c = new VariableReferenceExpression("c", BIGINT); PlanNodeStatsEstimate estimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(40) - .addSymbolStatistics(a, SymbolStatsEstimate.builder().setDistinctValuesCount(20).build()) - .addSymbolStatistics(b, SymbolStatsEstimate.builder().setDistinctValuesCount(30).build()) - .addSymbolStatistics(c, SymbolStatsEstimate.unknown()) + .addVariableStatistics(a, VariableStatsEstimate.builder().setDistinctValuesCount(20).build()) + .addVariableStatistics(b, VariableStatsEstimate.builder().setDistinctValuesCount(30).build()) + .addVariableStatistics(c, VariableStatsEstimate.unknown()) .build(); - PlanNodeStatsAssertion.assertThat(normalizer.normalize(estimate, ImmutableList.of(b, c), TypeProvider.copyOf(ImmutableMap.of(b, BIGINT, c, BIGINT)))) - .symbolsWithKnownStats(b) - .symbolStats(b, symbolAssert -> symbolAssert.distinctValuesCount(30)); + PlanNodeStatsAssertion.assertThat(normalizer.normalize(estimate, ImmutableList.of(b, c))) + .variablesWithKnownStats(b) + .variableStats(b, variableAssert -> variableAssert.distinctValuesCount(30)); } @Test public void tesCapDistinctValuesByOutputRowCount() { - Symbol a = new Symbol("a"); - Symbol b = new Symbol("b"); - Symbol c = new Symbol("c"); + VariableReferenceExpression a = new VariableReferenceExpression("a", BIGINT); + VariableReferenceExpression b = new VariableReferenceExpression("b", BIGINT); + VariableReferenceExpression c = new VariableReferenceExpression("c", BIGINT); PlanNodeStatsEstimate estimate = PlanNodeStatsEstimate.builder() - .addSymbolStatistics(a, SymbolStatsEstimate.builder().setNullsFraction(0).setDistinctValuesCount(20).build()) - .addSymbolStatistics(b, SymbolStatsEstimate.builder().setNullsFraction(0.4).setDistinctValuesCount(20).build()) - .addSymbolStatistics(c, SymbolStatsEstimate.unknown()) + .addVariableStatistics(a, VariableStatsEstimate.builder().setNullsFraction(0).setDistinctValuesCount(20).build()) + .addVariableStatistics(b, VariableStatsEstimate.builder().setNullsFraction(0.4).setDistinctValuesCount(20).build()) + .addVariableStatistics(c, VariableStatsEstimate.unknown()) .setOutputRowCount(10) .build(); assertNormalized(estimate) - .symbolStats(a, symbolAssert -> symbolAssert.distinctValuesCount(10)) - .symbolStats(b, symbolAssert -> symbolAssert.distinctValuesCount(8)) - .symbolStats(c, SymbolStatsAssertion::distinctValuesCountUnknown); + .variableStats(a, variableAssert -> variableAssert.distinctValuesCount(10)) + .variableStats(b, variableAssert -> variableAssert.distinctValuesCount(8)) + .variableStats(c, VariableStatsAssertion::distinctValuesCountUnknown); } @Test @@ -133,8 +129,8 @@ public void testCapDistinctValuesByToDomainRangeLength() private void testCapDistinctValuesByToDomainRangeLength(Type type, double ndv, Object low, Object high, double expectedNormalizedNdv) { - Symbol symbol = new Symbol("x"); - SymbolStatsEstimate symbolStats = SymbolStatsEstimate.builder() + VariableReferenceExpression variable = new VariableReferenceExpression("x", type); + VariableStatsEstimate symbolStats = VariableStatsEstimate.builder() .setNullsFraction(0) .setDistinctValuesCount(ndv) .setLowValue(asStatsValue(low, type)) @@ -142,22 +138,15 @@ private void testCapDistinctValuesByToDomainRangeLength(Type type, double ndv, O .build(); PlanNodeStatsEstimate estimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(10000000000L) - .addSymbolStatistics(symbol, symbolStats).build(); + .addVariableStatistics(variable, symbolStats).build(); - assertNormalized(estimate, TypeProvider.copyOf(ImmutableMap.of(symbol, type))) - .symbolStats(symbol, symbolAssert -> symbolAssert.distinctValuesCount(expectedNormalizedNdv)); + assertNormalized(estimate) + .variableStats(variable, variableAssert -> variableAssert.distinctValuesCount(expectedNormalizedNdv)); } private PlanNodeStatsAssertion assertNormalized(PlanNodeStatsEstimate estimate) { - TypeProvider types = TypeProvider.copyOf(estimate.getSymbolsWithKnownStatistics().stream() - .collect(toImmutableMap(identity(), symbol -> BIGINT))); - return assertNormalized(estimate, types); - } - - private PlanNodeStatsAssertion assertNormalized(PlanNodeStatsEstimate estimate, TypeProvider types) - { - PlanNodeStatsEstimate normalized = normalizer.normalize(estimate, estimate.getSymbolsWithKnownStatistics(), types); + PlanNodeStatsEstimate normalized = normalizer.normalize(estimate, estimate.getVariablesWithKnownStatistics()); return PlanNodeStatsAssertion.assertThat(normalized); } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestUnionStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestUnionStatsRule.java index 6998959a4f386..3630e9a91fd98 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestUnionStatsRule.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestUnionStatsRule.java @@ -14,7 +14,7 @@ package com.facebook.presto.cost; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import org.testng.annotations.Test; @@ -38,43 +38,43 @@ public void testUnion() tester().assertStatsFor(pb -> pb .union( - ImmutableListMultimap.builder() - .putAll(pb.symbol("o1", BIGINT), pb.symbol("i11", BIGINT), pb.symbol("i21", BIGINT)) - .putAll(pb.symbol("o2", BIGINT), pb.symbol("i12", BIGINT), pb.symbol("i22", BIGINT)) - .putAll(pb.symbol("o3", BIGINT), pb.symbol("i13", BIGINT), pb.symbol("i23", BIGINT)) - .putAll(pb.symbol("o4", BIGINT), pb.symbol("i14", BIGINT), pb.symbol("i24", BIGINT)) - .putAll(pb.symbol("o5", BIGINT), pb.symbol("i15", BIGINT), pb.symbol("i25", BIGINT)) + ImmutableListMultimap.builder() + .putAll(pb.variable("o1"), pb.variable("i11"), pb.variable("i21")) + .putAll(pb.variable("o2"), pb.variable("i12"), pb.variable("i22")) + .putAll(pb.variable("o3"), pb.variable("i13"), pb.variable("i23")) + .putAll(pb.variable("o4"), pb.variable("i14"), pb.variable("i24")) + .putAll(pb.variable("o5"), pb.variable("i15"), pb.variable("i25")) .build(), ImmutableList.of( - pb.values(pb.symbol("i11", BIGINT), pb.symbol("i12", BIGINT), pb.symbol("i13", BIGINT), pb.symbol("i14", BIGINT), pb.symbol("i15", BIGINT)), - pb.values(pb.symbol("i21", BIGINT), pb.symbol("i22", BIGINT), pb.symbol("i23", BIGINT), pb.symbol("i24", BIGINT), pb.symbol("i25", BIGINT))))) + pb.values(pb.variable("i11"), pb.variable("i12"), pb.variable("i13"), pb.variable("i14"), pb.variable("i15")), + pb.values(pb.variable("i21"), pb.variable("i22"), pb.variable("i23"), pb.variable("i24"), pb.variable("i25"))))) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(10) - .addSymbolStatistics(new Symbol("i11"), SymbolStatsEstimate.builder() + .addVariableStatistics(variable("i11"), VariableStatsEstimate.builder() .setLowValue(1) .setHighValue(10) .setDistinctValuesCount(5) .setNullsFraction(0.3) .build()) - .addSymbolStatistics(new Symbol("i12"), SymbolStatsEstimate.builder() + .addVariableStatistics(variable("i12"), VariableStatsEstimate.builder() .setLowValue(0) .setHighValue(3) .setDistinctValuesCount(4) .setNullsFraction(0) .build()) - .addSymbolStatistics(new Symbol("i13"), SymbolStatsEstimate.builder() + .addVariableStatistics(variable("i13"), VariableStatsEstimate.builder() .setLowValue(10) .setHighValue(15) .setDistinctValuesCount(4) .setNullsFraction(0.1) .build()) - .addSymbolStatistics(new Symbol("i14"), SymbolStatsEstimate.builder() + .addVariableStatistics(variable("i14"), VariableStatsEstimate.builder() .setLowValue(10) .setHighValue(15) .setDistinctValuesCount(4) .setNullsFraction(0.1) .build()) - .addSymbolStatistics(new Symbol("i15"), SymbolStatsEstimate.builder() + .addVariableStatistics(variable("i15"), VariableStatsEstimate.builder() .setLowValue(10) .setHighValue(15) .setDistinctValuesCount(4) @@ -83,56 +83,61 @@ public void testUnion() .build()) .withSourceStats(1, PlanNodeStatsEstimate.builder() .setOutputRowCount(20) - .addSymbolStatistics(new Symbol("i21"), SymbolStatsEstimate.builder() + .addVariableStatistics(variable("i21"), VariableStatsEstimate.builder() .setLowValue(11) .setHighValue(20) .setNullsFraction(0.4) .build()) - .addSymbolStatistics(new Symbol("i22"), SymbolStatsEstimate.builder() + .addVariableStatistics(variable("i22"), VariableStatsEstimate.builder() .setLowValue(2) .setHighValue(7) .setDistinctValuesCount(3) .build()) - .addSymbolStatistics(new Symbol("i23"), SymbolStatsEstimate.builder() + .addVariableStatistics(variable("i23"), VariableStatsEstimate.builder() .setDistinctValuesCount(6) .setNullsFraction(0.2) .build()) - .addSymbolStatistics(new Symbol("i24"), SymbolStatsEstimate.builder() + .addVariableStatistics(variable("i24"), VariableStatsEstimate.builder() .setLowValue(10) .setHighValue(15) .setDistinctValuesCount(4) .setNullsFraction(0.1) .build()) - .addSymbolStatistics(new Symbol("i25"), SymbolStatsEstimate.builder() + .addVariableStatistics(variable("i25"), VariableStatsEstimate.builder() .setNullsFraction(1) .build()) .build()) .check(check -> check .outputRowsCount(30) - .symbolStats("o1", assertion -> assertion + .variableStats(variable("o1"), assertion -> assertion .lowValue(1) .highValue(20) .dataSizeUnknown() .nullsFraction(0.3666666)) - .symbolStats("o2", assertion -> assertion + .variableStats(variable("o2"), assertion -> assertion .lowValue(0) .highValue(7) .distinctValuesCount(6.4) .nullsFractionUnknown()) - .symbolStats("o3", assertion -> assertion + .variableStats(variable("o3"), assertion -> assertion .lowValueUnknown() .highValueUnknown() .distinctValuesCount(8.5) .nullsFraction(0.1666667)) - .symbolStats("o4", assertion -> assertion + .variableStats(variable("o4"), assertion -> assertion .lowValue(10) .highValue(15) .distinctValuesCount(4.0) .nullsFraction(0.1)) - .symbolStats("o5", assertion -> assertion + .variableStats(variable("o5"), assertion -> assertion .lowValue(NEGATIVE_INFINITY) .highValue(POSITIVE_INFINITY) .distinctValuesCountUnknown() .nullsFraction(0.7))); } + + private VariableReferenceExpression variable(String name) + { + return new VariableReferenceExpression(name, BIGINT); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestUnnestStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestUnnestStatsRule.java index cadb2a8bb542a..0c8d40a5f43ef 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestUnnestStatsRule.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestUnnestStatsRule.java @@ -13,13 +13,16 @@ */ package com.facebook.presto.cost; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; import java.util.Optional; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.util.StructuralTestUtil.mapType; + public class TestUnnestStatsRule extends BaseStatsCalculatorTest { @@ -28,13 +31,13 @@ public void testUnnestStatsNotPopulatedForMultiRow() { tester().assertStatsFor( pb -> pb.unnest( - pb.values(pb.symbol("some_map")), - ImmutableList.of(pb.symbol("some_map")), - ImmutableMap.of(pb.symbol("some_map"), ImmutableList.of(pb.symbol("key"), pb.symbol("value"))), + pb.values(pb.variable("some_map", mapType(VARCHAR, VARCHAR))), + ImmutableList.of(pb.variable("some_map", mapType(VARCHAR, VARCHAR))), + ImmutableMap.of(pb.variable("some_map", mapType(VARCHAR, VARCHAR)), ImmutableList.of(pb.variable("key"), pb.variable("value"))), Optional.empty())) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(2) - .addSymbolStatistics(new Symbol("some_map"), SymbolStatsEstimate.builder().setAverageRowSize(100).build()) + .addVariableStatistics(new VariableReferenceExpression("some_map", mapType(VARCHAR, VARCHAR)), VariableStatsEstimate.builder().setAverageRowSize(100).build()) .build()) .check(check -> check.equalTo(PlanNodeStatsEstimate.unknown())); } @@ -44,18 +47,18 @@ public void testUnntestStatsPopulated() { tester().assertStatsFor( pb -> pb.unnest( - pb.values(pb.symbol("some_map")), - ImmutableList.of(pb.symbol("some_map")), - ImmutableMap.of(pb.symbol("some_map"), ImmutableList.of(pb.symbol("key"), pb.symbol("value"))), + pb.values(pb.variable("some_map", mapType(VARCHAR, VARCHAR))), + ImmutableList.of(pb.variable("some_map", mapType(VARCHAR, VARCHAR))), + ImmutableMap.of(pb.variable("some_map", mapType(VARCHAR, VARCHAR)), ImmutableList.of(pb.variable("key", VARCHAR), pb.variable("value", VARCHAR))), Optional.empty())) .withSourceStats(0, PlanNodeStatsEstimate.builder() .setOutputRowCount(1) - .addSymbolStatistics(new Symbol("some_map"), SymbolStatsEstimate.builder().setAverageRowSize(100).build()) + .addVariableStatistics(new VariableReferenceExpression("some_map", mapType(VARCHAR, VARCHAR)), VariableStatsEstimate.builder().setAverageRowSize(100).build()) .build()) .check(check -> check .outputRowsCount(1) - .symbolStats("some_map", assertion -> assertion.averageRowSize(100)) - .symbolStats("key", assertion -> assertion.averageRowSize(100)) - .symbolStats("value", assertion -> assertion.averageRowSize(100))); + .variableStats(new VariableReferenceExpression("some_map", mapType(VARCHAR, VARCHAR)), assertion -> assertion.averageRowSize(100)) + .variableStats(new VariableReferenceExpression("key", VARCHAR), assertion -> assertion.averageRowSize(100)) + .variableStats(new VariableReferenceExpression("value", VARCHAR), assertion -> assertion.averageRowSize(100))); } } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestValuesNodeStats.java b/presto-main/src/test/java/com/facebook/presto/cost/TestValuesNodeStats.java index e0c8799f8f230..3a62874be75f3 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestValuesNodeStats.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestValuesNodeStats.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.cost; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.relational.FunctionResolution; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; @@ -37,7 +37,8 @@ public void testStatsForValuesNode() { FunctionResolution resolution = new FunctionResolution(tester().getMetadata().getFunctionManager()); tester().assertStatsFor(pb -> pb - .values(ImmutableList.of(pb.symbol("a", BIGINT), pb.symbol("b", DOUBLE)), + .values( + ImmutableList.of(pb.variable("a", BIGINT), pb.variable("b", DOUBLE)), ImmutableList.of( ImmutableList.of(call(ADD.name(), resolution.arithmeticFunction(ADD, BIGINT, BIGINT), BIGINT, constantExpressions(BIGINT, 3L, 3L)), constant(13.5, DOUBLE)), ImmutableList.of(constant(55, BIGINT), constantNull(DOUBLE)), @@ -45,17 +46,17 @@ public void testStatsForValuesNode() .check(outputStats -> outputStats.equalTo( PlanNodeStatsEstimate.builder() .setOutputRowCount(3) - .addSymbolStatistics( - new Symbol("a"), - SymbolStatsEstimate.builder() + .addVariableStatistics( + new VariableReferenceExpression("a", BIGINT), + VariableStatsEstimate.builder() .setNullsFraction(0) .setLowValue(6) .setHighValue(55) .setDistinctValuesCount(2) .build()) - .addSymbolStatistics( - new Symbol("b"), - SymbolStatsEstimate.builder() + .addVariableStatistics( + new VariableReferenceExpression("b", DOUBLE), + VariableStatsEstimate.builder() .setNullsFraction(0.33333333333333333) .setLowValue(13.5) .setHighValue(13.5) @@ -64,7 +65,8 @@ public void testStatsForValuesNode() .build())); tester().assertStatsFor(pb -> pb - .values(ImmutableList.of(pb.symbol("v", createVarcharType(30))), + .values( + ImmutableList.of(pb.variable("v", createVarcharType(30))), ImmutableList.of( constantExpressions(VARCHAR, "Alice"), constantExpressions(VARCHAR, "has"), @@ -73,9 +75,9 @@ public void testStatsForValuesNode() .check(outputStats -> outputStats.equalTo( PlanNodeStatsEstimate.builder() .setOutputRowCount(4) - .addSymbolStatistics( - new Symbol("v"), - SymbolStatsEstimate.builder() + .addVariableStatistics( + new VariableReferenceExpression("v", createVarcharType(30)), + VariableStatsEstimate.builder() .setNullsFraction(0.25) .setDistinctValuesCount(3) // TODO .setAverageRowSize(4 + 1. / 3) @@ -87,38 +89,47 @@ public void testStatsForValuesNode() public void testStatsForValuesNodeWithJustNulls() { FunctionResolution resolution = new FunctionResolution(tester().getMetadata().getFunctionManager()); - PlanNodeStatsEstimate nullAStats = PlanNodeStatsEstimate.builder() + PlanNodeStatsEstimate bigintNullAStats = PlanNodeStatsEstimate.builder() .setOutputRowCount(1) - .addSymbolStatistics(new Symbol("a"), SymbolStatsEstimate.zero()) + .addVariableStatistics(new VariableReferenceExpression("a", BIGINT), VariableStatsEstimate.zero()) .build(); tester().assertStatsFor(pb -> pb - .values(ImmutableList.of(pb.symbol("a", BIGINT)), + .values( + ImmutableList.of(pb.variable("a", BIGINT)), ImmutableList.of( ImmutableList.of(call(ADD.name(), resolution.arithmeticFunction(ADD, BIGINT, BIGINT), BIGINT, constant(3, BIGINT), constantNull(BIGINT)))))) - .check(outputStats -> outputStats.equalTo(nullAStats)); + .check(outputStats -> outputStats.equalTo(bigintNullAStats)); tester().assertStatsFor(pb -> pb - .values(ImmutableList.of(pb.symbol("a", BIGINT)), + .values( + ImmutableList.of(pb.variable("a", BIGINT)), ImmutableList.of(ImmutableList.of(constantNull(BIGINT))))) - .check(outputStats -> outputStats.equalTo(nullAStats)); + .check(outputStats -> outputStats.equalTo(bigintNullAStats)); + + PlanNodeStatsEstimate unknownNullAStats = PlanNodeStatsEstimate.builder() + .setOutputRowCount(1) + .addVariableStatistics(new VariableReferenceExpression("a", UNKNOWN), VariableStatsEstimate.zero()) + .build(); tester().assertStatsFor(pb -> pb - .values(ImmutableList.of(pb.symbol("a", UNKNOWN)), + .values( + ImmutableList.of(pb.variable("a", UNKNOWN)), ImmutableList.of(ImmutableList.of(constantNull(UNKNOWN))))) - .check(outputStats -> outputStats.equalTo(nullAStats)); + .check(outputStats -> outputStats.equalTo(unknownNullAStats)); } @Test public void testStatsForEmptyValues() { tester().assertStatsFor(pb -> pb - .values(ImmutableList.of(pb.symbol("a", BIGINT)), + .values( + ImmutableList.of(pb.variable("a", BIGINT)), ImmutableList.of())) .check(outputStats -> outputStats.equalTo( PlanNodeStatsEstimate.builder() .setOutputRowCount(0) - .addSymbolStatistics(new Symbol("a"), SymbolStatsEstimate.zero()) + .addVariableStatistics(new VariableReferenceExpression("a", BIGINT), VariableStatsEstimate.zero()) .build())); } } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/SymbolStatsAssertion.java b/presto-main/src/test/java/com/facebook/presto/cost/VariableStatsAssertion.java similarity index 76% rename from presto-main/src/test/java/com/facebook/presto/cost/SymbolStatsAssertion.java rename to presto-main/src/test/java/com/facebook/presto/cost/VariableStatsAssertion.java index 17af5066e3326..74c897aa34e9b 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/SymbolStatsAssertion.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/VariableStatsAssertion.java @@ -21,50 +21,50 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; -public class SymbolStatsAssertion +public class VariableStatsAssertion { - private final SymbolStatsEstimate statistics; + private final VariableStatsEstimate statistics; - private SymbolStatsAssertion(SymbolStatsEstimate statistics) + private VariableStatsAssertion(VariableStatsEstimate statistics) { this.statistics = requireNonNull(statistics, "statistics is null"); } - public static SymbolStatsAssertion assertThat(SymbolStatsEstimate actual) + public static VariableStatsAssertion assertThat(VariableStatsEstimate actual) { - return new SymbolStatsAssertion(actual); + return new VariableStatsAssertion(actual); } - public SymbolStatsAssertion nullsFraction(double expected) + public VariableStatsAssertion nullsFraction(double expected) { assertEstimateEquals(statistics.getNullsFraction(), expected, "nullsFraction mismatch"); return this; } - public SymbolStatsAssertion nullsFractionUnknown() + public VariableStatsAssertion nullsFractionUnknown() { assertTrue(isNaN(statistics.getNullsFraction()), "expected unknown nullsFraction but got " + statistics.getNullsFraction()); return this; } - public SymbolStatsAssertion lowValue(double expected) + public VariableStatsAssertion lowValue(double expected) { assertEstimateEquals(statistics.getLowValue(), expected, "lowValue mismatch"); return this; } - public SymbolStatsAssertion lowValueUnknown() + public VariableStatsAssertion lowValueUnknown() { return lowValue(NEGATIVE_INFINITY); } - public SymbolStatsAssertion highValue(double expected) + public VariableStatsAssertion highValue(double expected) { assertEstimateEquals(statistics.getHighValue(), expected, "highValue mismatch"); return this; } - public SymbolStatsAssertion highValueUnknown() + public VariableStatsAssertion highValueUnknown() { return highValue(POSITIVE_INFINITY); } @@ -76,7 +76,7 @@ public void empty() .nullsFraction(1); } - public SymbolStatsAssertion emptyRange() + public VariableStatsAssertion emptyRange() { assertTrue(isNaN(statistics.getLowValue()) && isNaN(statistics.getHighValue()), "expected empty range (NaN, NaN) but got (" + statistics.getLowValue() + ", " + statistics.getHighValue() + ") instead"); @@ -86,37 +86,37 @@ public SymbolStatsAssertion emptyRange() return this; } - public SymbolStatsAssertion unknownRange() + public VariableStatsAssertion unknownRange() { return lowValueUnknown() .highValueUnknown(); } - public SymbolStatsAssertion distinctValuesCount(double expected) + public VariableStatsAssertion distinctValuesCount(double expected) { assertEstimateEquals(statistics.getDistinctValuesCount(), expected, "distinctValuesCount mismatch"); return this; } - public SymbolStatsAssertion distinctValuesCountUnknown() + public VariableStatsAssertion distinctValuesCountUnknown() { assertTrue(isNaN(statistics.getDistinctValuesCount()), "expected unknown distinctValuesCount but got " + statistics.getDistinctValuesCount()); return this; } - public SymbolStatsAssertion averageRowSize(double expected) + public VariableStatsAssertion averageRowSize(double expected) { assertEstimateEquals(statistics.getAverageRowSize(), expected, "average row size mismatch"); return this; } - public SymbolStatsAssertion dataSizeUnknown() + public VariableStatsAssertion dataSizeUnknown() { assertTrue(isNaN(statistics.getAverageRowSize()), "expected unknown dataSize but got " + statistics.getAverageRowSize()); return this; } - public SymbolStatsAssertion isEqualTo(SymbolStatsEstimate expected) + public VariableStatsAssertion isEqualTo(VariableStatsEstimate expected) { return nullsFraction(expected.getNullsFraction()) .lowValue(expected.getLowValue()) diff --git a/presto-main/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java b/presto-main/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java index d202baa2285ac..0375eb2ee7ed7 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java @@ -31,11 +31,11 @@ import com.facebook.presto.operator.TaskStats; import com.facebook.presto.spi.memory.MemoryPoolId; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spiller.SpillSpaceTracker; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanFragment; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.TableScanNode; @@ -105,19 +105,19 @@ public MockRemoteTaskFactory(Executor executor, ScheduledExecutorService schedul public MockRemoteTask createTableScanTask(TaskId taskId, InternalNode newNode, List splits, PartitionedSplitCountTracker partitionedSplitCountTracker) { - Symbol symbol = new Symbol("column"); + VariableReferenceExpression variable = new VariableReferenceExpression("column", VARCHAR); PlanNodeId sourceId = new PlanNodeId("sourceId"); PlanFragment testFragment = new PlanFragment( new PlanFragmentId(0), new TableScanNode( sourceId, new TableHandle(new ConnectorId("test"), new TestingTableHandle(), TestingTransactionHandle.create(), Optional.of(TestingHandle.INSTANCE)), - ImmutableList.of(symbol), - ImmutableMap.of(symbol, new TestingColumnHandle("column"))), - ImmutableMap.of(symbol, VARCHAR), + ImmutableList.of(variable), + ImmutableMap.of(variable, new TestingColumnHandle("column"))), + ImmutableSet.of(variable), SOURCE_DISTRIBUTION, ImmutableList.of(sourceId), - new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)), + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(variable)), StageExecutionDescriptor.ungroupedExecution(), false, StatsAndCosts.empty(), diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java b/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java index 767470315c18e..26f19d597dc1a 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java @@ -34,6 +34,7 @@ import com.facebook.presto.operator.index.IndexJoinLookupStats; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.TestingTypeManager; import com.facebook.presto.spiller.GenericSpillerFactory; import com.facebook.presto.split.PageSinkManager; @@ -60,6 +61,7 @@ import com.facebook.presto.util.FinalizerService; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import io.airlift.json.ObjectMapperProvider; import java.util.List; @@ -68,7 +70,6 @@ import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.spi.type.BigintType.BIGINT; -import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; import static io.airlift.json.JsonCodec.jsonCodec; @@ -91,17 +92,19 @@ private TaskTestUtils() public static final Symbol SYMBOL = new Symbol("column"); + public static final VariableReferenceExpression VARIABLE = new VariableReferenceExpression("column", BIGINT); + public static final PlanFragment PLAN_FRAGMENT = new PlanFragment( new PlanFragmentId(0), new TableScanNode( TABLE_SCAN_NODE_ID, new TableHandle(CONNECTOR_ID, new TestingTableHandle(), TRANSACTION_HANDLE, Optional.empty()), - ImmutableList.of(SYMBOL), - ImmutableMap.of(SYMBOL, new TestingColumnHandle("column", 0, BIGINT))), - ImmutableMap.of(SYMBOL, VARCHAR), + ImmutableList.of(VARIABLE), + ImmutableMap.of(VARIABLE, new TestingColumnHandle("column", 0, BIGINT))), + ImmutableSet.of(VARIABLE), SOURCE_DISTRIBUTION, ImmutableList.of(TABLE_SCAN_NODE_ID), - new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(SYMBOL)) + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(VARIABLE)) .withBucketToPartition(Optional.of(new int[1])), StageExecutionDescriptor.ungroupedExecution(), false, diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TestSqlStageExecution.java b/presto-main/src/test/java/com/facebook/presto/execution/TestSqlStageExecution.java index e8a8468833182..c2b9f3911742b 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/TestSqlStageExecution.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/TestSqlStageExecution.java @@ -22,17 +22,16 @@ import com.facebook.presto.operator.StageExecutionDescriptor; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanFragment; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.util.FinalizerService; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.SettableFuture; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; @@ -161,21 +160,17 @@ private static PlanFragment createExchangePlanFragment() PlanNode planNode = new RemoteSourceNode( new PlanNodeId("exchange"), ImmutableList.of(new PlanFragmentId(0)), - ImmutableList.of(new Symbol("column")), + ImmutableList.of(new VariableReferenceExpression("column", VARCHAR)), Optional.empty(), REPARTITION); - ImmutableMap.Builder types = ImmutableMap.builder(); - for (Symbol symbol : planNode.getOutputSymbols()) { - types.put(symbol, VARCHAR); - } return new PlanFragment( new PlanFragmentId(0), planNode, - types.build(), + ImmutableSet.copyOf(planNode.getOutputVariables()), SOURCE_DISTRIBUTION, ImmutableList.of(planNode.getId()), - new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), planNode.getOutputSymbols()), + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), planNode.getOutputVariables()), StageExecutionDescriptor.ungroupedExecution(), false, StatsAndCosts.empty(), diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TestStageStateMachine.java b/presto-main/src/test/java/com/facebook/presto/execution/TestStageStateMachine.java index f1db221615dae..4daed76d7b3dc 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/TestStageStateMachine.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/TestStageStateMachine.java @@ -18,14 +18,14 @@ import com.facebook.presto.operator.StageExecutionDescriptor; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanFragment; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; @@ -321,17 +321,17 @@ private StageStateMachine createStageStateMachine() private static PlanFragment createValuesPlan() { - Symbol symbol = new Symbol("column"); + VariableReferenceExpression variable = new VariableReferenceExpression("column", VARCHAR); PlanNodeId valuesNodeId = new PlanNodeId("plan"); PlanFragment planFragment = new PlanFragment( new PlanFragmentId(0), new ValuesNode(valuesNodeId, - ImmutableList.of(symbol), + ImmutableList.of(variable), ImmutableList.of(ImmutableList.of(constant("foo", VARCHAR)))), - ImmutableMap.of(symbol, VARCHAR), + ImmutableSet.of(variable), SOURCE_DISTRIBUTION, ImmutableList.of(valuesNodeId), - new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)), + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(variable)), StageExecutionDescriptor.ungroupedExecution(), false, StatsAndCosts.empty(), diff --git a/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java b/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java index 859be0a28e832..a9a1320aec2f0 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java @@ -18,7 +18,7 @@ import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.operator.StageExecutionDescriptor; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanFragment; @@ -44,7 +44,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Stream; -import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; @@ -172,10 +172,14 @@ private static PlanFragment createUnionPlanFragment(String name, PlanFragment... PlanNode planNode = new UnionNode( new PlanNodeId(name + "_id"), Stream.of(fragments) - .map(fragment -> new RemoteSourceNode(new PlanNodeId(fragment.getId().toString()), fragment.getId(), fragment.getPartitioningScheme().getOutputLayout(), Optional.empty(), REPARTITION)) + .map(fragment -> new RemoteSourceNode( + new PlanNodeId(fragment.getId().toString()), + fragment.getId(), + fragment.getPartitioningScheme().getOutputLayout(), + Optional.empty(), + REPARTITION)) .collect(toImmutableList()), - ImmutableListMultimap.of(), - ImmutableList.of()); + ImmutableListMultimap.of()); return createFragment(planNode); } @@ -183,6 +187,7 @@ private static PlanFragment createUnionPlanFragment(String name, PlanFragment... private static PlanFragment createBroadcastJoinPlanFragment(String name, PlanFragment buildFragment) { Symbol symbol = new Symbol("column"); + VariableReferenceExpression variable = new VariableReferenceExpression("column", BIGINT); PlanNode tableScan = new TableScanNode( new PlanNodeId(name), new TableHandle( @@ -190,8 +195,8 @@ private static PlanFragment createBroadcastJoinPlanFragment(String name, PlanFra new TestingTableHandle(), TestingTransactionHandle.create(), Optional.empty()), - ImmutableList.of(symbol), - ImmutableMap.of(symbol, new TestingColumnHandle("column"))); + ImmutableList.of(variable), + ImmutableMap.of(variable, new TestingColumnHandle("column"))); RemoteSourceNode remote = new RemoteSourceNode(new PlanNodeId("build_id"), buildFragment.getId(), ImmutableList.of(), Optional.empty(), REPLICATE); PlanNode join = new JoinNode( @@ -200,9 +205,9 @@ private static PlanFragment createBroadcastJoinPlanFragment(String name, PlanFra tableScan, remote, ImmutableList.of(), - ImmutableList.builder() - .addAll(tableScan.getOutputSymbols()) - .addAll(remote.getOutputSymbols()) + ImmutableList.builder() + .addAll(tableScan.getOutputVariables()) + .addAll(remote.getOutputVariables()) .build(), Optional.empty(), Optional.empty(), @@ -222,9 +227,9 @@ private static PlanFragment createJoinPlanFragment(JoinNode.Type joinType, Strin probe, build, ImmutableList.of(), - ImmutableList.builder() - .addAll(probe.getOutputSymbols()) - .addAll(build.getOutputSymbols()) + ImmutableList.builder() + .addAll(probe.getOutputVariables()) + .addAll(build.getOutputVariables()) .build(), Optional.empty(), Optional.empty(), @@ -237,6 +242,7 @@ private static PlanFragment createJoinPlanFragment(JoinNode.Type joinType, Strin private static PlanFragment createTableScanPlanFragment(String name) { Symbol symbol = new Symbol("column"); + VariableReferenceExpression variable = new VariableReferenceExpression("column", BIGINT); PlanNode planNode = new TableScanNode( new PlanNodeId(name), new TableHandle( @@ -244,25 +250,21 @@ private static PlanFragment createTableScanPlanFragment(String name) new TestingTableHandle(), TestingTransactionHandle.create(), Optional.empty()), - ImmutableList.of(symbol), - ImmutableMap.of(symbol, new TestingColumnHandle("column"))); + ImmutableList.of(variable), + ImmutableMap.of(variable, new TestingColumnHandle("column"))); return createFragment(planNode); } private static PlanFragment createFragment(PlanNode planNode) { - ImmutableMap.Builder types = ImmutableMap.builder(); - for (Symbol symbol : planNode.getOutputSymbols()) { - types.put(symbol, VARCHAR); - } return new PlanFragment( new PlanFragmentId(nextPlanFragmentId.incrementAndGet()), planNode, - types.build(), + ImmutableSet.copyOf(planNode.getOutputVariables()), SOURCE_DISTRIBUTION, ImmutableList.of(planNode.getId()), - new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), planNode.getOutputSymbols()), + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), planNode.getOutputVariables()), StageExecutionDescriptor.ungroupedExecution(), false, StatsAndCosts.empty(), diff --git a/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java b/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java index f4789bf84ed41..bef64e9e1d774 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java @@ -38,13 +38,13 @@ import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.connector.ConnectorPartitionHandle; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.split.ConnectorAwareSplitSource; import com.facebook.presto.split.SplitSource; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.sql.planner.SubPlan; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; @@ -57,6 +57,7 @@ import com.google.common.base.Supplier; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -450,14 +451,14 @@ private static StageScheduler getSourcePartitionedScheduler( private static SubPlan createPlan() { - Symbol symbol = new Symbol("column"); + VariableReferenceExpression variable = new VariableReferenceExpression("column", VARCHAR); // table scan with splitCount splits TableScanNode tableScan = new TableScanNode( TABLE_SCAN_NODE_ID, new TableHandle(CONNECTOR_ID, new TestingTableHandle(), TestingTransactionHandle.create(), Optional.empty()), - ImmutableList.of(symbol), - ImmutableMap.of(symbol, new TestingColumnHandle("column")), + ImmutableList.of(variable), + ImmutableMap.of(variable, new TestingColumnHandle("column")), false); RemoteSourceNode remote = new RemoteSourceNode(new PlanNodeId("remote_id"), new PlanFragmentId(0), ImmutableList.of(), Optional.empty(), GATHER); @@ -468,18 +469,18 @@ private static SubPlan createPlan() tableScan, remote, ImmutableList.of(), - ImmutableList.builder() - .addAll(tableScan.getOutputSymbols()) - .addAll(remote.getOutputSymbols()) + ImmutableList.builder() + .addAll(tableScan.getOutputVariables()) + .addAll(remote.getOutputVariables()) .build(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()), - ImmutableMap.of(symbol, VARCHAR), + ImmutableSet.of(variable), SOURCE_DISTRIBUTION, ImmutableList.of(TABLE_SCAN_NODE_ID), - new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)), + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(variable)), StageExecutionDescriptor.ungroupedExecution(), false, StatsAndCosts.empty(), diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java index 0d11b1ef4e4dc..d6e8e4740df72 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java @@ -48,6 +48,7 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.predicate.Utils; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.TimeZoneKey; import com.facebook.presto.spi.type.Type; import com.facebook.presto.split.PageSourceProvider; @@ -128,6 +129,7 @@ import static com.facebook.presto.sql.relational.SqlToRowExpressionTranslator.translate; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; import static com.facebook.presto.type.UnknownType.UNKNOWN; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.slice.SizeOf.sizeOf; import static io.airlift.testing.Assertions.assertInstanceOf; @@ -165,31 +167,21 @@ public final class FunctionAssertions private static final Page ZERO_CHANNEL_PAGE = new Page(1); - private static final Map INPUT_MAPPING = ImmutableMap.builder() - .put(new Symbol("bound_long"), 0) - .put(new Symbol("bound_string"), 1) - .put(new Symbol("bound_double"), 2) - .put(new Symbol("bound_boolean"), 3) - .put(new Symbol("bound_timestamp"), 4) - .put(new Symbol("bound_pattern"), 5) - .put(new Symbol("bound_null_string"), 6) - .put(new Symbol("bound_timestamp_with_timezone"), 7) - .put(new Symbol("bound_binary_literal"), 8) - .put(new Symbol("bound_integer"), 9) + private static final Map INPUT_MAPPING = ImmutableMap.builder() + .put(new VariableReferenceExpression("bound_long", BIGINT), 0) + .put(new VariableReferenceExpression("bound_string", VARCHAR), 1) + .put(new VariableReferenceExpression("bound_double", DOUBLE), 2) + .put(new VariableReferenceExpression("bound_boolean", BOOLEAN), 3) + .put(new VariableReferenceExpression("bound_timestamp", BIGINT), 4) + .put(new VariableReferenceExpression("bound_pattern", VARCHAR), 5) + .put(new VariableReferenceExpression("bound_null_string", VARCHAR), 6) + .put(new VariableReferenceExpression("bound_timestamp_with_timezone", TIMESTAMP_WITH_TIME_ZONE), 7) + .put(new VariableReferenceExpression("bound_binary_literal", VARBINARY), 8) + .put(new VariableReferenceExpression("bound_integer", INTEGER), 9) .build(); - private static final TypeProvider SYMBOL_TYPES = TypeProvider.copyOf(ImmutableMap.builder() - .put(new Symbol("bound_long"), BIGINT) - .put(new Symbol("bound_string"), VARCHAR) - .put(new Symbol("bound_double"), DOUBLE) - .put(new Symbol("bound_boolean"), BOOLEAN) - .put(new Symbol("bound_timestamp"), BIGINT) - .put(new Symbol("bound_pattern"), VARCHAR) - .put(new Symbol("bound_null_string"), VARCHAR) - .put(new Symbol("bound_timestamp_with_timezone"), TIMESTAMP_WITH_TIME_ZONE) - .put(new Symbol("bound_binary_literal"), VARBINARY) - .put(new Symbol("bound_integer"), INTEGER) - .build()); + private static final TypeProvider SYMBOL_TYPES = TypeProvider.copyOf(INPUT_MAPPING.keySet().stream() + .collect(toImmutableMap(variable -> new Symbol(variable.getName()), VariableReferenceExpression::getType))); private static final PageSourceProvider PAGE_SOURCE_PROVIDER = new TestPageSourceProvider(); private static final PlanNodeId SOURCE_ID = new PlanNodeId("scan"); @@ -867,7 +859,7 @@ private Object interpret(Expression expression, Type expectedType, Session sessi Object result = evaluator.evaluate(symbol -> { int position = 0; - int channel = INPUT_MAPPING.get(symbol); + int channel = INPUT_MAPPING.get(new VariableReferenceExpression(symbol.getName(), SYMBOL_TYPES.get(symbol))); Type type = SYMBOL_TYPES.get(symbol); Block block = SOURCE_PAGE.getBlock(channel); @@ -964,7 +956,7 @@ private static SourceOperatorFactory compileScanFilterProject(Optional, Type> expressionTypes, Map layout) + private RowExpression toRowExpression(Expression projection, Map, Type> expressionTypes, Map layout) { return translate(projection, expressionTypes, layout, metadata.getFunctionManager(), metadata.getTypeManager(), session, false); } diff --git a/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java b/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java index b676d1fd5a40b..262dc74e5adf6 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java +++ b/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java @@ -38,8 +38,10 @@ import com.facebook.presto.server.smile.SmileModule; import com.facebook.presto.spi.ErrorCode; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.sql.Serialization; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.testing.TestingHandleResolver; import com.facebook.presto.testing.TestingSplit; @@ -244,6 +246,8 @@ public void configure(Binder binder) jsonCodecBinder(binder).bindJsonCodec(TaskStatus.class); jsonCodecBinder(binder).bindJsonCodec(TaskInfo.class); jsonCodecBinder(binder).bindJsonCodec(TaskUpdateRequest.class); + jsonBinder(binder).addKeySerializerBinding(VariableReferenceExpression.class).to(Serialization.VariableReferenceExpressionSerializer.class); + jsonBinder(binder).addKeyDeserializerBinding(VariableReferenceExpression.class).to(Serialization.VariableReferenceExpressionDeserializer.class); } @Provides diff --git a/presto-main/src/test/java/com/facebook/presto/sql/gen/PageProcessorBenchmark.java b/presto-main/src/test/java/com/facebook/presto/sql/gen/PageProcessorBenchmark.java index 6f90b337fa54a..f0df6fcc6edb5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/gen/PageProcessorBenchmark.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/gen/PageProcessorBenchmark.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.RecordSet; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Symbol; @@ -86,7 +87,7 @@ public class PageProcessorBenchmark private final DriverYieldSignal yieldSignal = new DriverYieldSignal(); private final Map symbolTypes = new HashMap<>(); - private final Map sourceLayout = new HashMap<>(); + private final Map sourceLayout = new HashMap<>(); private CursorProcessor cursorProcessor; private PageProcessor pageProcessor; @@ -109,9 +110,9 @@ public void setup() Type type = TYPE_MAP.get(this.type); for (int i = 0; i < columnCount; i++) { - Symbol symbol = new Symbol(type.getDisplayName().toLowerCase(ENGLISH) + i); - symbolTypes.put(symbol, type); - sourceLayout.put(symbol, i); + VariableReferenceExpression variable = new VariableReferenceExpression(type.getDisplayName().toLowerCase(ENGLISH) + i, type); + symbolTypes.put(new Symbol(variable.getName()), type); + sourceLayout.put(variable, i); } List projections = getProjections(type); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java index 4a8b5034e8b74..b954c53764137 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java @@ -23,6 +23,8 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.predicate.Domain; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.FilterNode; @@ -104,6 +106,13 @@ public class TestEffectivePredicateExtractor private static final Symbol E = new Symbol("e"); private static final Symbol F = new Symbol("f"); private static final Symbol G = new Symbol("g"); + private static final VariableReferenceExpression AV = new VariableReferenceExpression("a", BIGINT); + private static final VariableReferenceExpression BV = new VariableReferenceExpression("b", BIGINT); + private static final VariableReferenceExpression CV = new VariableReferenceExpression("c", BIGINT); + private static final VariableReferenceExpression DV = new VariableReferenceExpression("d", BIGINT); + private static final VariableReferenceExpression EV = new VariableReferenceExpression("e", BIGINT); + private static final VariableReferenceExpression FV = new VariableReferenceExpression("f", BIGINT); + private static final VariableReferenceExpression GV = new VariableReferenceExpression("g", BIGINT); private static final Expression AE = A.toSymbolReference(); private static final Expression BE = B.toSymbolReference(); private static final Expression CE = C.toSymbolReference(); @@ -112,26 +121,35 @@ public class TestEffectivePredicateExtractor private static final Expression FE = F.toSymbolReference(); private static final Expression GE = G.toSymbolReference(); + private static final TypeProvider types = TypeProvider.viewOf(ImmutableMap.builder() + .put(A, BIGINT) + .put(B, BIGINT) + .put(C, BIGINT) + .put(D, BIGINT) + .put(E, BIGINT) + .put(F, BIGINT) + .put(G, BIGINT) + .build()); private final Metadata metadata = MetadataManager.createTestMetadataManager(); private final EffectivePredicateExtractor effectivePredicateExtractor = new EffectivePredicateExtractor(new ExpressionDomainTranslator(new LiteralEncoder(metadata.getBlockEncodingSerde()))); - private Map scanAssignments; + private Map scanAssignments; private TableScanNode baseTableScan; private ExpressionIdentityNormalizer expressionNormalizer; @BeforeMethod public void setUp() { - scanAssignments = ImmutableMap.builder() - .put(A, new TestingColumnHandle("a")) - .put(B, new TestingColumnHandle("b")) - .put(C, new TestingColumnHandle("c")) - .put(D, new TestingColumnHandle("d")) - .put(E, new TestingColumnHandle("e")) - .put(F, new TestingColumnHandle("f")) + scanAssignments = ImmutableMap.builder() + .put(AV, new TestingColumnHandle("a")) + .put(BV, new TestingColumnHandle("b")) + .put(CV, new TestingColumnHandle("c")) + .put(DV, new TestingColumnHandle("d")) + .put(EV, new TestingColumnHandle("e")) + .put(FV, new TestingColumnHandle("f")) .build(); - Map assignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(A, B, C, D, E, F))); + Map assignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV, DV, EV, FV))); baseTableScan = new TableScanNode( newId(), DUAL_TABLE_HANDLE, @@ -157,15 +175,15 @@ public void testAggregation() greaterThan(AE, bigintLiteral(2)), equals(EE, FE))), ImmutableMap.of( - C, count(metadata.getFunctionManager()), - D, count(metadata.getFunctionManager())), - singleGroupingSet(ImmutableList.of(A, B, C)), + CV, count(metadata.getFunctionManager()), + DV, count(metadata.getFunctionManager())), + singleGroupingSet(ImmutableList.of(AV, BV, CV)), ImmutableList.of(), AggregationNode.Step.FINAL, Optional.empty(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); // Rewrite in terms of group by symbols assertEquals(normalizeConjuncts(effectivePredicate), @@ -189,7 +207,7 @@ public void testGroupByEmpty() Optional.empty(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); assertEquals(effectivePredicate, TRUE_LITERAL); } @@ -202,7 +220,7 @@ public void testFilter() greaterThan(AE, new FunctionCall(QualifiedName.of("rand"), ImmutableList.of())), lessThan(BE, bigintLiteral(10)))); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); // Non-deterministic functions should be purged assertEquals(normalizeConjuncts(effectivePredicate), @@ -218,9 +236,9 @@ public void testProject() equals(AE, BE), equals(BE, CE), lessThan(CE, bigintLiteral(10)))), - Assignments.of(D, AE, E, CE)); + Assignments.of(DV, AE, EV, CE)); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); // Rewrite in terms of project output symbols assertEquals(normalizeConjuncts(effectivePredicate), @@ -238,9 +256,9 @@ public void testTopN() equals(AE, BE), equals(BE, CE), lessThan(CE, bigintLiteral(10)))), - 1, new OrderingScheme(ImmutableList.of(A), ImmutableMap.of(A, SortOrder.ASC_NULLS_LAST)), TopNNode.Step.PARTIAL); + 1, new OrderingScheme(ImmutableList.of(AV), ImmutableMap.of(AV, SortOrder.ASC_NULLS_LAST)), TopNNode.Step.PARTIAL); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); // Pass through assertEquals(normalizeConjuncts(effectivePredicate), @@ -262,7 +280,7 @@ public void testLimit() 1, false); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); // Pass through assertEquals(normalizeConjuncts(effectivePredicate), @@ -281,9 +299,9 @@ public void testSort() equals(AE, BE), equals(BE, CE), lessThan(CE, bigintLiteral(10)))), - new OrderingScheme(ImmutableList.of(A), ImmutableMap.of(A, SortOrder.ASC_NULLS_LAST))); + new OrderingScheme(ImmutableList.of(AV), ImmutableMap.of(AV, SortOrder.ASC_NULLS_LAST))); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); // Pass through assertEquals(normalizeConjuncts(effectivePredicate), @@ -303,16 +321,16 @@ public void testWindow() equals(BE, CE), lessThan(CE, bigintLiteral(10)))), new WindowNode.Specification( - ImmutableList.of(A), + ImmutableList.of(AV), Optional.of(new OrderingScheme( - ImmutableList.of(A), - ImmutableMap.of(A, SortOrder.ASC_NULLS_LAST)))), + ImmutableList.of(AV), + ImmutableMap.of(AV, SortOrder.ASC_NULLS_LAST)))), ImmutableMap.of(), Optional.empty(), ImmutableSet.of(), 0); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); // Pass through assertEquals(normalizeConjuncts(effectivePredicate), @@ -326,13 +344,13 @@ public void testWindow() public void testTableScan() { // Effective predicate is True if there is no effective predicate - Map assignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(A, B, C, D))); + Map assignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV, DV))); PlanNode node = new TableScanNode( newId(), DUAL_TABLE_HANDLE, ImmutableList.copyOf(assignments.keySet()), assignments); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); assertEquals(effectivePredicate, BooleanLiteral.TRUE_LITERAL); node = new TableScanNode( @@ -342,7 +360,7 @@ public void testTableScan() assignments, TupleDomain.none(), TupleDomain.all()); - effectivePredicate = effectivePredicateExtractor.extract(node); + effectivePredicate = effectivePredicateExtractor.extract(node, types); assertEquals(effectivePredicate, FALSE_LITERAL); node = new TableScanNode( @@ -350,9 +368,9 @@ public void testTableScan() DUAL_TABLE_HANDLE_WITH_LAYOUT, ImmutableList.copyOf(assignments.keySet()), assignments, - TupleDomain.withColumnDomains(ImmutableMap.of(scanAssignments.get(A), Domain.singleValue(BIGINT, 1L))), + TupleDomain.withColumnDomains(ImmutableMap.of(scanAssignments.get(AV), Domain.singleValue(BIGINT, 1L))), TupleDomain.all()); - effectivePredicate = effectivePredicateExtractor.extract(node); + effectivePredicate = effectivePredicateExtractor.extract(node, types); assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts(equals(bigintLiteral(1L), AE))); node = new TableScanNode( @@ -361,10 +379,10 @@ public void testTableScan() ImmutableList.copyOf(assignments.keySet()), assignments, TupleDomain.withColumnDomains(ImmutableMap.of( - scanAssignments.get(A), Domain.singleValue(BIGINT, 1L), - scanAssignments.get(B), Domain.singleValue(BIGINT, 2L))), + scanAssignments.get(AV), Domain.singleValue(BIGINT, 1L), + scanAssignments.get(BV), Domain.singleValue(BIGINT, 2L))), TupleDomain.all()); - effectivePredicate = effectivePredicateExtractor.extract(node); + effectivePredicate = effectivePredicateExtractor.extract(node, types); assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts(equals(bigintLiteral(2L), BE), equals(bigintLiteral(1L), AE))); node = new TableScanNode( @@ -374,23 +392,22 @@ public void testTableScan() assignments, TupleDomain.all(), TupleDomain.all()); - effectivePredicate = effectivePredicateExtractor.extract(node); + effectivePredicate = effectivePredicateExtractor.extract(node, types); assertEquals(effectivePredicate, BooleanLiteral.TRUE_LITERAL); } @Test public void testUnion() { - ImmutableListMultimap symbolMapping = ImmutableListMultimap.of(A, B, A, C, A, E); + ImmutableListMultimap variableMapping = ImmutableListMultimap.of(AV, BV, AV, CV, AV, EV); PlanNode node = new UnionNode(newId(), ImmutableList.of( filter(baseTableScan, greaterThan(AE, bigintLiteral(10))), filter(baseTableScan, and(greaterThan(AE, bigintLiteral(10)), lessThan(AE, bigintLiteral(100)))), filter(baseTableScan, and(greaterThan(AE, bigintLiteral(10)), lessThan(AE, bigintLiteral(100))))), - symbolMapping, - ImmutableList.copyOf(symbolMapping.keySet())); + variableMapping); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); // Only the common conjuncts can be inferred through a Union assertEquals(normalizeConjuncts(effectivePredicate), @@ -401,14 +418,14 @@ public void testUnion() public void testInnerJoin() { ImmutableList.Builder criteriaBuilder = ImmutableList.builder(); - criteriaBuilder.add(new JoinNode.EquiJoinClause(A, D)); - criteriaBuilder.add(new JoinNode.EquiJoinClause(B, E)); + criteriaBuilder.add(new JoinNode.EquiJoinClause(AV, DV)); + criteriaBuilder.add(new JoinNode.EquiJoinClause(BV, EV)); List criteria = criteriaBuilder.build(); - Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(A, B, C))); + Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV))); TableScanNode leftScan = tableScanNode(leftAssignments); - Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(D, E, F))); + Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(DV, EV, FV))); TableScanNode rightScan = tableScanNode(rightAssignments); FilterNode left = filter(leftScan, @@ -426,16 +443,16 @@ public void testInnerJoin() left, right, criteria, - ImmutableList.builder() - .addAll(left.getOutputSymbols()) - .addAll(right.getOutputSymbols()) + ImmutableList.builder() + .addAll(left.getOutputVariables()) + .addAll(right.getOutputVariables()) .build(), Optional.of(castToRowExpression(lessThanOrEqual(BE, EE))), Optional.empty(), Optional.empty(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); // All predicates having output symbol should be carried through assertEquals(normalizeConjuncts(effectivePredicate), @@ -451,10 +468,10 @@ public void testInnerJoin() @Test public void testInnerJoinPropagatesPredicatesViaEquiConditions() { - Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(A, B, C))); + Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV))); TableScanNode leftScan = tableScanNode(leftAssignments); - Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(D, E, F))); + Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(DV, EV, FV))); TableScanNode rightScan = tableScanNode(rightAssignments); FilterNode left = filter(leftScan, equals(AE, bigintLiteral(10))); @@ -464,16 +481,16 @@ public void testInnerJoinPropagatesPredicatesViaEquiConditions() JoinNode.Type.INNER, left, rightScan, - ImmutableList.of(new JoinNode.EquiJoinClause(A, D)), - ImmutableList.builder() - .addAll(rightScan.getOutputSymbols()) + ImmutableList.of(new JoinNode.EquiJoinClause(AV, DV)), + ImmutableList.builder() + .addAll(rightScan.getOutputVariables()) .build(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); assertEquals( normalizeConjuncts(effectivePredicate), @@ -483,27 +500,27 @@ public void testInnerJoinPropagatesPredicatesViaEquiConditions() @Test public void testInnerJoinWithFalseFilter() { - Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(A, B, C))); + Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV))); TableScanNode leftScan = tableScanNode(leftAssignments); - Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(D, E, F))); + Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(DV, EV, FV))); TableScanNode rightScan = tableScanNode(rightAssignments); PlanNode node = new JoinNode(newId(), JoinNode.Type.INNER, leftScan, rightScan, - ImmutableList.of(new JoinNode.EquiJoinClause(A, D)), - ImmutableList.builder() - .addAll(leftScan.getOutputSymbols()) - .addAll(rightScan.getOutputSymbols()) + ImmutableList.of(new JoinNode.EquiJoinClause(AV, DV)), + ImmutableList.builder() + .addAll(leftScan.getOutputVariables()) + .addAll(rightScan.getOutputVariables()) .build(), Optional.of(castToRowExpression(FALSE_LITERAL)), Optional.empty(), Optional.empty(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); assertEquals(effectivePredicate, FALSE_LITERAL); } @@ -512,14 +529,14 @@ public void testInnerJoinWithFalseFilter() public void testLeftJoin() { ImmutableList.Builder criteriaBuilder = ImmutableList.builder(); - criteriaBuilder.add(new JoinNode.EquiJoinClause(A, D)); - criteriaBuilder.add(new JoinNode.EquiJoinClause(B, E)); + criteriaBuilder.add(new JoinNode.EquiJoinClause(AV, DV)); + criteriaBuilder.add(new JoinNode.EquiJoinClause(BV, EV)); List criteria = criteriaBuilder.build(); - Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(A, B, C))); + Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV))); TableScanNode leftScan = tableScanNode(leftAssignments); - Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(D, E, F))); + Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(DV, EV, FV))); TableScanNode rightScan = tableScanNode(rightAssignments); FilterNode left = filter(leftScan, @@ -536,16 +553,16 @@ public void testLeftJoin() left, right, criteria, - ImmutableList.builder() - .addAll(left.getOutputSymbols()) - .addAll(right.getOutputSymbols()) + ImmutableList.builder() + .addAll(left.getOutputVariables()) + .addAll(right.getOutputVariables()) .build(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); // All right side symbols having output symbols should be checked against NULL assertEquals(normalizeConjuncts(effectivePredicate), @@ -560,12 +577,12 @@ public void testLeftJoin() @Test public void testLeftJoinWithFalseInner() { - List criteria = ImmutableList.of(new JoinNode.EquiJoinClause(A, D)); + List criteria = ImmutableList.of(new JoinNode.EquiJoinClause(AV, DV)); - Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(A, B, C))); + Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV))); TableScanNode leftScan = tableScanNode(leftAssignments); - Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(D, E, F))); + Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(DV, EV, FV))); TableScanNode rightScan = tableScanNode(rightAssignments); FilterNode left = filter(leftScan, @@ -579,16 +596,16 @@ public void testLeftJoinWithFalseInner() left, right, criteria, - ImmutableList.builder() - .addAll(left.getOutputSymbols()) - .addAll(right.getOutputSymbols()) + ImmutableList.builder() + .addAll(left.getOutputVariables()) + .addAll(right.getOutputVariables()) .build(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); // False literal on the right side should be ignored assertEquals(normalizeConjuncts(effectivePredicate), @@ -601,14 +618,14 @@ public void testLeftJoinWithFalseInner() public void testRightJoin() { ImmutableList.Builder criteriaBuilder = ImmutableList.builder(); - criteriaBuilder.add(new JoinNode.EquiJoinClause(A, D)); - criteriaBuilder.add(new JoinNode.EquiJoinClause(B, E)); + criteriaBuilder.add(new JoinNode.EquiJoinClause(AV, DV)); + criteriaBuilder.add(new JoinNode.EquiJoinClause(BV, EV)); List criteria = criteriaBuilder.build(); - Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(A, B, C))); + Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV))); TableScanNode leftScan = tableScanNode(leftAssignments); - Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(D, E, F))); + Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(DV, EV, FV))); TableScanNode rightScan = tableScanNode(rightAssignments); FilterNode left = filter(leftScan, @@ -625,16 +642,16 @@ public void testRightJoin() left, right, criteria, - ImmutableList.builder() - .addAll(left.getOutputSymbols()) - .addAll(right.getOutputSymbols()) + ImmutableList.builder() + .addAll(left.getOutputVariables()) + .addAll(right.getOutputVariables()) .build(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); // All left side symbols should be checked against NULL assertEquals(normalizeConjuncts(effectivePredicate), @@ -649,12 +666,12 @@ public void testRightJoin() @Test public void testRightJoinWithFalseInner() { - List criteria = ImmutableList.of(new JoinNode.EquiJoinClause(A, D)); + List criteria = ImmutableList.of(new JoinNode.EquiJoinClause(AV, DV)); - Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(A, B, C))); + Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV))); TableScanNode leftScan = tableScanNode(leftAssignments); - Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(D, E, F))); + Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(DV, EV, FV))); TableScanNode rightScan = tableScanNode(rightAssignments); FilterNode left = filter(leftScan, FALSE_LITERAL); @@ -667,16 +684,16 @@ public void testRightJoinWithFalseInner() left, right, criteria, - ImmutableList.builder() - .addAll(left.getOutputSymbols()) - .addAll(right.getOutputSymbols()) + ImmutableList.builder() + .addAll(left.getOutputVariables()) + .addAll(right.getOutputVariables()) .build(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); // False literal on the left side should be ignored assertEquals(normalizeConjuncts(effectivePredicate), @@ -691,19 +708,19 @@ public void testSemiJoin() PlanNode node = new SemiJoinNode(newId(), filter(baseTableScan, and(greaterThan(AE, bigintLiteral(10)), lessThan(AE, bigintLiteral(100)))), filter(baseTableScan, greaterThan(AE, bigintLiteral(5))), - A, B, C, + AV, BV, CV, Optional.empty(), Optional.empty(), Optional.empty()); - Expression effectivePredicate = effectivePredicateExtractor.extract(node); + Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); // Currently, only pull predicates through the source plan assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts(and(greaterThan(AE, bigintLiteral(10)), lessThan(AE, bigintLiteral(100))))); } - private static TableScanNode tableScanNode(Map scanAssignments) + private static TableScanNode tableScanNode(Map scanAssignments) { return new TableScanNode( newId(), @@ -778,11 +795,11 @@ private Set normalizeConjuncts(Expression predicate) Set rewrittenSet = new HashSet<>(); for (Expression expression : EqualityInference.nonInferrableConjuncts(predicate)) { - Expression rewritten = inference.rewriteExpression(expression, Predicates.alwaysTrue()); + Expression rewritten = inference.rewriteExpression(expression, Predicates.alwaysTrue(), types); Preconditions.checkState(rewritten != null, "Rewrite with full symbol scope should always be possible"); rewrittenSet.add(rewritten); } - rewrittenSet.addAll(inference.generateEqualitiesPartitionedBy(Predicates.alwaysTrue()).getScopeEqualities()); + rewrittenSet.addAll(inference.generateEqualitiesPartitionedBy(Predicates.alwaysTrue(), types).getScopeEqualities()); return rewrittenSet; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEqualityInference.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEqualityInference.java index cfb9069126f5b..22cd764f9a4f4 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEqualityInference.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEqualityInference.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.ArrayConstructor; @@ -45,14 +47,18 @@ import java.util.Arrays; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.QueryUtil.identifier; import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.EQUAL; import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.GREATER_THAN; import static com.google.common.base.Predicates.not; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.function.Function.identity; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNull; @@ -77,26 +83,26 @@ public void testTransitivity() EqualityInference inference = builder.build(); assertEquals( - inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("d1", "d2")), + inference.rewriteExpression(someExpression("a1", "a2"), matchesVariables("d1", "d2"), types("a1", "a2", "b1", "b2", "c1", "c2", "d1", "d2")), someExpression("d1", "d2")); assertEquals( - inference.rewriteExpression(someExpression("a1", "c1"), matchesSymbols("b1")), + inference.rewriteExpression(someExpression("a1", "c1"), matchesVariables("b1"), types("a1", "a2", "b1", "b2", "c1", "c2", "d1", "d2")), someExpression("b1", "b1")); assertEquals( - inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("b1", "d2", "c3")), + inference.rewriteExpression(someExpression("a1", "a2"), matchesVariables("b1", "d2", "c3"), types("a1", "a2", "b1", "b2", "c1", "c2", "c3", "d1", "d2")), someExpression("b1", "d2")); // Both starting expressions should canonicalize to the same expression assertEquals( - inference.getScopedCanonical(nameReference("a2"), matchesSymbols("c2", "d2")), - inference.getScopedCanonical(nameReference("b2"), matchesSymbols("c2", "d2"))); - Expression canonical = inference.getScopedCanonical(nameReference("a2"), matchesSymbols("c2", "d2")); + inference.getScopedCanonical(nameReference("a2"), matchesVariables("c2", "d2"), types("a2", "b2", "c2", "d2")), + inference.getScopedCanonical(nameReference("b2"), matchesVariables("c2", "d2"), types("a2", "b2", "c2", "d2"))); + Expression canonical = inference.getScopedCanonical(nameReference("a2"), matchesVariables("c2", "d2"), types("a2", "b2", "c2", "d2")); // Given multiple translatable candidates, should choose the canonical assertEquals( - inference.rewriteExpression(someExpression("a2", "b2"), matchesSymbols("c2", "d2")), + inference.rewriteExpression(someExpression("a2", "b2"), matchesVariables("c2", "d2"), types("a2", "b2", "c2", "d2")), someExpression(canonical, canonical)); } @@ -105,7 +111,7 @@ public void testTriviallyRewritable() { EqualityInference.Builder builder = new EqualityInference.Builder(); Expression expression = builder.build() - .rewriteExpression(someExpression("a1", "a2"), matchesSymbols("a1", "a2")); + .rewriteExpression(someExpression("a1", "a2"), matchesVariables("a1", "a2"), types("a1", "a2")); assertEquals(expression, someExpression("a1", "a2")); } @@ -118,8 +124,8 @@ public void testUnrewritable() addEquality("a2", "b2", builder); EqualityInference inference = builder.build(); - assertNull(inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("b1", "c1"))); - assertNull(inference.rewriteExpression(someExpression("c1", "c2"), matchesSymbols("a1", "a2"))); + assertNull(inference.rewriteExpression(someExpression("a1", "a2"), matchesVariables("b1", "c1"), types("a1", "a2", "b1", "b2"))); + assertNull(inference.rewriteExpression(someExpression("c1", "c2"), matchesVariables("a1", "a2"), types("a1", "a2", "c1", "c2"))); } @Test @@ -131,7 +137,7 @@ public void testParseEqualityExpression() .addEquality(equals("c1", "a1")) .build(); - Expression expression = inference.rewriteExpression(someExpression("a1", "b1"), matchesSymbols("c1")); + Expression expression = inference.rewriteExpression(someExpression("a1", "b1"), matchesVariables("c1"), types("a1", "b1", "c1")); assertEquals(expression, someExpression("c1", "c1")); } @@ -164,10 +170,10 @@ public void testExtractInferrableEqualities() .build(); // Able to rewrite to c1 due to equalities - assertEquals(nameReference("c1"), inference.rewriteExpression(nameReference("a1"), matchesSymbols("c1"))); + assertEquals(nameReference("c1"), inference.rewriteExpression(nameReference("a1"), matchesVariables("c1"), types("a1", "b1", "c1", "d1"))); // But not be able to rewrite to d1 which is not connected via equality - assertNull(inference.rewriteExpression(nameReference("a1"), matchesSymbols("d1"))); + assertNull(inference.rewriteExpression(nameReference("a1"), matchesVariables("d1"), types("a1", "b1", "c1", "d1"))); } @Test @@ -182,7 +188,7 @@ public void testEqualityPartitionGeneration() EqualityInference inference = builder.build(); - EqualityInference.EqualityPartition emptyScopePartition = inference.generateEqualitiesPartitionedBy(Predicates.alwaysFalse()); + EqualityInference.EqualityPartition emptyScopePartition = inference.generateEqualitiesPartitionedBy(Predicates.alwaysFalse(), types("a1", "b1", "c1")); // Cannot generate any scope equalities with no matching symbols assertTrue(emptyScopePartition.getScopeEqualities().isEmpty()); // All equalities should be represented in the inverse scope @@ -190,21 +196,21 @@ public void testEqualityPartitionGeneration() // There should be no equalities straddling the scope assertTrue(emptyScopePartition.getScopeStraddlingEqualities().isEmpty()); - EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("c1")); + EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesVariables("c1"), types("a1", "b1", "c1")); // There should be equalities in the scope, that only use c1 and are all inferrable equalities assertFalse(equalityPartition.getScopeEqualities().isEmpty()); - assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(matchesSymbols("c1")))); + assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesVariableScope(matchesVariables("c1"), types("a1", "b1", "c1")))); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate())); // There should be equalities in the inverse scope, that never use c1 and are all inferrable equalities assertFalse(equalityPartition.getScopeComplementEqualities().isEmpty()); - assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), matchesSymbolScope(not(matchesSymbols("c1"))))); + assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), matchesVariableScope(not(matchesVariables("c1")), types("a1", "b1", "c1")))); assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), EqualityInference.isInferenceCandidate())); // There should be equalities in the straddling scope, that should use both c1 and not c1 symbols assertFalse(equalityPartition.getScopeStraddlingEqualities().isEmpty()); - assertTrue(Iterables.any(equalityPartition.getScopeStraddlingEqualities(), matchesStraddlingScope(matchesSymbols("c1")))); + assertTrue(Iterables.any(equalityPartition.getScopeStraddlingEqualities(), matchesStraddlingScope(matchesVariables("c1"), types("a1", "b1", "c1")))); assertTrue(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), EqualityInference.isInferenceCandidate())); // There should be a "full cover" of all of the equalities used @@ -215,7 +221,7 @@ public void testEqualityPartitionGeneration() .addAllEqualities(equalityPartition.getScopeStraddlingEqualities()) .build(); - EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(matchesSymbols("c1")); + EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(matchesVariables("c1"), types("a1", "b1", "c1")); assertEquals(setCopy(equalityPartition.getScopeEqualities()), setCopy(newEqualityPartition.getScopeEqualities())); assertEquals(setCopy(equalityPartition.getScopeComplementEqualities()), setCopy(newEqualityPartition.getScopeComplementEqualities())); @@ -237,21 +243,21 @@ public void testMultipleEqualitySetsPredicateGeneration() EqualityInference inference = builder.build(); // Generating equalities for disjoint groups - EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(symbolBeginsWith("a", "b")); + EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(variableBeginsWith("a", "b"), types("a1", "b1", "c1", "a2", "b2", "c2", "d1", "d2")); // There should be equalities in the scope, that only use a* and b* symbols and are all inferrable equalities assertFalse(equalityPartition.getScopeEqualities().isEmpty()); - assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(symbolBeginsWith("a", "b")))); + assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesVariableScope(variableBeginsWith("a", "b"), types("a1", "b1", "c1", "a2", "b2", "c2", "d1", "d2")))); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate())); // There should be equalities in the inverse scope, that never use a* and b* symbols and are all inferrable equalities assertFalse(equalityPartition.getScopeComplementEqualities().isEmpty()); - assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), matchesSymbolScope(not(symbolBeginsWith("a", "b"))))); + assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), matchesVariableScope(not(variableBeginsWith("a", "b")), types("a1", "b1", "c1", "a2", "b2", "c2", "d1", "d2")))); assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), EqualityInference.isInferenceCandidate())); // There should be equalities in the straddling scope, that should use both c1 and not c1 symbols assertFalse(equalityPartition.getScopeStraddlingEqualities().isEmpty()); - assertTrue(Iterables.any(equalityPartition.getScopeStraddlingEqualities(), matchesStraddlingScope(symbolBeginsWith("a", "b")))); + assertTrue(Iterables.any(equalityPartition.getScopeStraddlingEqualities(), matchesStraddlingScope(variableBeginsWith("a", "b"), types("a1", "b1", "c1", "a2", "b2", "c2", "d1", "d2")))); assertTrue(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), EqualityInference.isInferenceCandidate())); // Again, there should be a "full cover" of all of the equalities used @@ -262,7 +268,7 @@ public void testMultipleEqualitySetsPredicateGeneration() .addAllEqualities(equalityPartition.getScopeStraddlingEqualities()) .build(); - EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(symbolBeginsWith("a", "b")); + EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(variableBeginsWith("a", "b"), types("a1", "b1", "c1", "a2", "b2", "c2", "d1", "d2")); assertEquals(setCopy(equalityPartition.getScopeEqualities()), setCopy(newEqualityPartition.getScopeEqualities())); assertEquals(setCopy(equalityPartition.getScopeComplementEqualities()), setCopy(newEqualityPartition.getScopeComplementEqualities())); @@ -279,13 +285,13 @@ public void testSubExpressionRewrites() EqualityInference inference = builder.build(); // Expression (b + c) should get entirely rewritten as a1 - assertEquals(inference.rewriteExpression(add("b", "c"), symbolBeginsWith("a")), nameReference("a1")); + assertEquals(inference.rewriteExpression(add("b", "c"), variableBeginsWith("a"), types("a1", "a2", "a3", "b", "c")), nameReference("a1")); // Only the sub-expression (b + c) should get rewritten in terms of a* - assertEquals(inference.rewriteExpression(multiply(nameReference("ax"), add("b", "c")), symbolBeginsWith("a")), multiply(nameReference("ax"), nameReference("a1"))); + assertEquals(inference.rewriteExpression(multiply(nameReference("ax"), add("b", "c")), variableBeginsWith("a"), types("a1", "a2", "a3", "b", "c", "ax")), multiply(nameReference("ax"), nameReference("a1"))); // To be compliant, could rewrite either the whole expression, or just the sub-expression. Rewriting larger expressions are preferred - assertEquals(inference.rewriteExpression(multiply(nameReference("a1"), add("b", "c")), symbolBeginsWith("a")), nameReference("a3")); + assertEquals(inference.rewriteExpression(multiply(nameReference("a1"), add("b", "c")), variableBeginsWith("a"), types("a1", "a2", "a3", "b", "c")), nameReference("a3")); } @Test @@ -298,10 +304,10 @@ public void testConstantEqualities() EqualityInference inference = builder.build(); // Should always prefer a constant if available (constant is part of all scopes) - assertEquals(inference.rewriteExpression(nameReference("a1"), matchesSymbols("a1", "b1")), number(1)); + assertEquals(inference.rewriteExpression(nameReference("a1"), matchesVariables("a1", "b1"), types("a1", "b1", "c1")), number(1)); // All scope equalities should utilize the constant if possible - EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("a1", "b1")); + EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesVariables("a1", "b1"), types("a1", "b1", "c1")); assertEquals(equalitiesAsSets(equalityPartition.getScopeEqualities()), set(set(nameReference("a1"), number(1)), set(nameReference("b1"), number(1)))); assertEquals(equalitiesAsSets(equalityPartition.getScopeComplementEqualities()), @@ -320,7 +326,7 @@ public void testEqualityGeneration() addEquality("c", "d", builder); EqualityInference inference = builder.build(); - Expression scopedCanonical = inference.getScopedCanonical(nameReference("e1"), symbolBeginsWith("a")); + Expression scopedCanonical = inference.getScopedCanonical(nameReference("e1"), variableBeginsWith("a"), types("a1", "b", "c", "d", "e1")); assertEquals(scopedCanonical, nameReference("a1")); } @@ -344,22 +350,22 @@ public void testExpressionsThatMayReturnNullOnNonNullInput() builder.extractInferenceCandidates(equals(nameReference("a"), candidate)); EqualityInference inference = builder.build(); - List equalities = inference.generateEqualitiesPartitionedBy(matchesSymbols("b")).getScopeStraddlingEqualities(); + List equalities = inference.generateEqualitiesPartitionedBy(matchesVariables("b"), types("a", "b", "x")).getScopeStraddlingEqualities(); assertEquals(equalities.size(), 1); assertTrue(equalities.get(0).equals(equals(nameReference("x"), nameReference("b"))) || equalities.get(0).equals(equals(nameReference("b"), nameReference("x")))); } } - private static Predicate matchesSymbolScope(final Predicate symbolScope) + private static Predicate matchesVariableScope(final Predicate variableScope, TypeProvider types) { - return expression -> Iterables.all(SymbolsExtractor.extractUnique(expression), symbolScope); + return expression -> Iterables.all(SymbolsExtractor.extractUniqueVariable(expression, types), variableScope); } - private static Predicate matchesStraddlingScope(final Predicate symbolScope) + private static Predicate matchesStraddlingScope(final Predicate variableScope, TypeProvider types) { return expression -> { - Set symbols = SymbolsExtractor.extractUnique(expression); - return Iterables.any(symbols, symbolScope) && Iterables.any(symbols, not(symbolScope)); + Set variables = SymbolsExtractor.extractUniqueVariable(expression, types); + return Iterables.any(variables, variableScope) && Iterables.any(variables, not(variableScope)); }; } @@ -418,30 +424,30 @@ private static LongLiteral number(long number) return new LongLiteral(String.valueOf(number)); } - private static Predicate matchesSymbols(String... symbols) + private static Predicate matchesVariables(String... variables) { - return matchesSymbols(Arrays.asList(symbols)); + return matchesVariables(Arrays.asList(variables)); } - private static Predicate matchesSymbols(Collection symbols) + private static Predicate matchesVariables(Collection variables) { - final Set symbolSet = symbols.stream() - .map(Symbol::new) + final Set symbolSet = variables.stream() + .map(name -> new VariableReferenceExpression(name, BIGINT)) .collect(toImmutableSet()); return Predicates.in(symbolSet); } - private static Predicate symbolBeginsWith(String... prefixes) + private static Predicate variableBeginsWith(String... prefixes) { - return symbolBeginsWith(Arrays.asList(prefixes)); + return variableBeginsWith(Arrays.asList(prefixes)); } - private static Predicate symbolBeginsWith(final Iterable prefixes) + private static Predicate variableBeginsWith(final Iterable prefixes) { - return symbol -> { + return variable -> { for (String prefix : prefixes) { - if (symbol.getName().startsWith(prefix)) { + if (variable.getName().startsWith(prefix)) { return true; } } @@ -475,4 +481,10 @@ private static Set setCopy(Iterable elements) { return ImmutableSet.copyOf(elements); } + + private static TypeProvider types(String... variables) + { + Map types = Arrays.asList(variables).stream().map(Symbol::new).collect(toImmutableMap(identity(), ignore -> BIGINT)); + return TypeProvider.copyOf(types); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index 8cff058b5870c..66006f00c921f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner; import com.facebook.presto.Session; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; @@ -53,6 +54,7 @@ import static com.facebook.presto.spi.StandardErrorCode.SUBQUERY_MULTIPLE_ROWS; import static com.facebook.presto.spi.block.SortOrder.ASC_NULLS_LAST; import static com.facebook.presto.spi.predicate.Domain.singleValue; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any; @@ -845,7 +847,7 @@ public void testBroadcastCorrelatedSubqueryAvoidsRemoteExchangeBeforeAggregation countOfMatchingNodes( plan, node -> node instanceof AggregationNode - && ((AggregationNode) node).getGroupingKeys().contains(new Symbol("unique")) + && ((AggregationNode) node).getGroupingKeys().contains(new VariableReferenceExpression("unique", BIGINT)) && ((AggregationNode) node).isStreamable()), 1); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPredicatePushdown.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPredicatePushdown.java index 01bf876a8c6d7..c9a5050e895a0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPredicatePushdown.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPredicatePushdown.java @@ -430,9 +430,9 @@ public void testPredicatePushDownCreatesValidJoin() .on(p -> p.join(INNER, p.filter(new ComparisonExpression(EQUAL, p.symbol("a1").toSymbolReference(), new LongLiteral("1")), - p.values(p.symbol("a1"))), - p.values(p.symbol("b1")), - ImmutableList.of(new EquiJoinClause(p.symbol("a1"), p.symbol("b1"))), + p.values(p.variable("a1"))), + p.values(p.variable("b1")), + ImmutableList.of(new EquiJoinClause(p.variable("a1"), p.variable("b1"))), ImmutableList.of(), Optional.empty(), Optional.empty(), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSchedulingOrderVisitor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSchedulingOrderVisitor.java index 0c4b4c4766dfb..4318dcaa7b699 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSchedulingOrderVisitor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSchedulingOrderVisitor.java @@ -19,6 +19,7 @@ import com.facebook.presto.spi.TestingColumnHandle; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -43,8 +44,8 @@ public class TestSchedulingOrderVisitor public void testJoinOrder() { PlanBuilder planBuilder = new PlanBuilder(new PlanNodeIdAllocator(), METADATA); - TableScanNode a = planBuilder.tableScan(emptyList(), emptyMap()); - TableScanNode b = planBuilder.tableScan(emptyList(), emptyMap()); + TableScanNode a = planBuilder.tableScan(emptyList(), emptyList(), emptyMap()); + TableScanNode b = planBuilder.tableScan(emptyList(), emptyList(), emptyMap()); List order = scheduleOrder(planBuilder.join(JoinNode.Type.INNER, a, b)); assertEquals(order, ImmutableList.of(b.getId(), a.getId())); } @@ -53,8 +54,8 @@ public void testJoinOrder() public void testIndexJoinOrder() { PlanBuilder planBuilder = new PlanBuilder(new PlanNodeIdAllocator(), METADATA); - TableScanNode a = planBuilder.tableScan(emptyList(), emptyMap()); - TableScanNode b = planBuilder.tableScan(emptyList(), emptyMap()); + TableScanNode a = planBuilder.tableScan(emptyList(), emptyList(), emptyMap()); + TableScanNode b = planBuilder.tableScan(emptyList(), emptyList(), emptyMap()); List order = scheduleOrder(planBuilder.indexJoin(IndexJoinNode.Type.INNER, a, b)); assertEquals(order, ImmutableList.of(b.getId(), a.getId())); } @@ -63,14 +64,14 @@ public void testIndexJoinOrder() public void testSemiJoinOrder() { PlanBuilder planBuilder = new PlanBuilder(new PlanNodeIdAllocator(), METADATA); - Symbol sourceJoin = planBuilder.symbol("sourceJoin"); - TableScanNode a = planBuilder.tableScan(ImmutableList.of(sourceJoin), ImmutableMap.of(sourceJoin, new TestingColumnHandle("sourceJoin"))); - Symbol filteringSource = planBuilder.symbol("filteringSource"); - TableScanNode b = planBuilder.tableScan(ImmutableList.of(filteringSource), ImmutableMap.of(filteringSource, new TestingColumnHandle("filteringSource"))); + VariableReferenceExpression sourceJoin = planBuilder.variable("sourceJoin"); + TableScanNode a = planBuilder.tableScan(ImmutableList.of(new Symbol(sourceJoin.getName())), ImmutableList.of(sourceJoin), ImmutableMap.of(sourceJoin, new TestingColumnHandle("sourceJoin"))); + VariableReferenceExpression filteringSource = planBuilder.variable("filteringSource"); + TableScanNode b = planBuilder.tableScan(ImmutableList.of(new Symbol(filteringSource.getName())), ImmutableList.of(filteringSource), ImmutableMap.of(filteringSource, new TestingColumnHandle("filteringSource"))); List order = scheduleOrder(planBuilder.semiJoin( sourceJoin, filteringSource, - planBuilder.symbol("semiJoinOutput"), + planBuilder.variable("semiJoinOutput"), Optional.empty(), Optional.empty(), a, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSortExpressionExtractor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSortExpressionExtractor.java index cd5b9f935b265..5cecd5c4e9540 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSortExpressionExtractor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSortExpressionExtractor.java @@ -16,7 +16,6 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.spi.relation.VariableReferenceExpression; -import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.Expression; @@ -26,7 +25,6 @@ import java.util.Arrays; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.Set; @@ -40,8 +38,10 @@ public class TestSortExpressionExtractor { private static final Metadata METADATA = MetadataManager.createTestMetadataManager(); private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA); - private static final Set BUILD_SYMBOLS = ImmutableSet.of(new Symbol("b1"), new Symbol("b2")); - private static final Map SYMBOL_TYPES = ImmutableMap.of(new Symbol("b1"), BIGINT, new Symbol("b2"), BIGINT, new Symbol("p1"), BIGINT, new Symbol("p2"), BIGINT); + private static final Set BUILD_VARIABLES = ImmutableSet.of( + new VariableReferenceExpression("b1", BIGINT), + new VariableReferenceExpression("b2", BIGINT)); + private static final TypeProvider TYPES = TypeProvider.viewOf(ImmutableMap.of(new Symbol("b1"), BIGINT, new Symbol("b2"), BIGINT, new Symbol("p1"), BIGINT, new Symbol("p2"), BIGINT)); @Test public void testGetSortExpression() @@ -90,8 +90,8 @@ private void assertNoSortExpression(String expression) private void assertNoSortExpression(Expression expression) { Optional actual = SortExpressionExtractor.extractSortExpression( - BUILD_SYMBOLS, - TRANSLATOR.translate(expression, TypeProvider.copyOf(SYMBOL_TYPES)), + BUILD_VARIABLES, + TRANSLATOR.translate(expression, TYPES), METADATA.getFunctionManager()); assertEquals(actual, Optional.empty()); } @@ -124,10 +124,10 @@ private static void assertGetSortExpression(Expression expression, String expect { Optional expected = Optional.of(new SortExpressionContext( new VariableReferenceExpression(expectedSymbol, BIGINT), - searchExpressions.stream().map(e -> TRANSLATOR.translate(e, TypeProvider.copyOf(SYMBOL_TYPES))).collect(toImmutableList()))); + searchExpressions.stream().map(e -> TRANSLATOR.translate(e, TYPES)).collect(toImmutableList()))); Optional actual = SortExpressionExtractor.extractSortExpression( - BUILD_SYMBOLS, - TRANSLATOR.translate(expression, TypeProvider.copyOf(SYMBOL_TYPES)), + BUILD_VARIABLES, + TRANSLATOR.translate(expression, TYPES), METADATA.getFunctionManager()); assertEquals(actual, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java index 44c8f102ee3c4..4c08f27ff0a2e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java @@ -85,6 +85,11 @@ public class TestTypeValidator private Symbol columnC; private Symbol columnD; private Symbol columnE; + private VariableReferenceExpression variableA; + private VariableReferenceExpression variableB; + private VariableReferenceExpression variableC; + private VariableReferenceExpression variableD; + private VariableReferenceExpression variableE; @BeforeMethod public void setUp() @@ -96,12 +101,18 @@ public void setUp() columnD = symbolAllocator.newSymbol("d", DATE); columnE = symbolAllocator.newSymbol("e", VarcharType.createVarcharType(3)); // varchar(3), to test type only coercion - Map assignments = ImmutableMap.builder() - .put(columnA, new TestingColumnHandle("a")) - .put(columnB, new TestingColumnHandle("b")) - .put(columnC, new TestingColumnHandle("c")) - .put(columnD, new TestingColumnHandle("d")) - .put(columnE, new TestingColumnHandle("e")) + variableA = new VariableReferenceExpression(columnA.getName(), BIGINT); + variableB = new VariableReferenceExpression(columnB.getName(), INTEGER); + variableC = new VariableReferenceExpression(columnC.getName(), DOUBLE); + variableD = new VariableReferenceExpression(columnD.getName(), DATE); + variableE = new VariableReferenceExpression(columnE.getName(), VarcharType.createVarcharType(3)); + + Map assignments = ImmutableMap.builder() + .put(variableA, new TestingColumnHandle("a")) + .put(variableB, new TestingColumnHandle("b")) + .put(variableC, new TestingColumnHandle("c")) + .put(variableD, new TestingColumnHandle("d")) + .put(variableE, new TestingColumnHandle("e")) .build(); baseTableScan = new TableScanNode( @@ -119,8 +130,8 @@ public void testValidProject() Expression expression1 = new Cast(columnB.toSymbolReference(), StandardTypes.BIGINT); Expression expression2 = new Cast(columnC.toSymbolReference(), StandardTypes.BIGINT); Assignments assignments = Assignments.builder() - .put(symbolAllocator.newSymbol(expression1, BIGINT), expression1) - .put(symbolAllocator.newSymbol(expression2, BIGINT), expression2) + .put(symbolAllocator.newVariable(expression1, BIGINT), expression1) + .put(symbolAllocator.newVariable(expression2, BIGINT), expression2) .build(); PlanNode node = new ProjectNode( newId(), @@ -133,17 +144,16 @@ public void testValidProject() @Test public void testValidUnion() { - Symbol outputSymbol = symbolAllocator.newSymbol("output", DATE); - ListMultimap mappings = ImmutableListMultimap.builder() - .put(outputSymbol, columnD) - .put(outputSymbol, columnD) + VariableReferenceExpression output = symbolAllocator.newVariable("output", DATE); + ListMultimap mappings = ImmutableListMultimap.builder() + .put(output, variableD) + .put(output, variableD) .build(); PlanNode node = new UnionNode( newId(), ImmutableList.of(baseTableScan, baseTableScan), - mappings, - ImmutableList.copyOf(mappings.keySet())); + mappings); assertTypesValid(node); } @@ -152,6 +162,7 @@ public void testValidUnion() public void testValidWindow() { Symbol windowSymbol = symbolAllocator.newSymbol("sum", DOUBLE); + VariableReferenceExpression windowVariable = new VariableReferenceExpression(windowSymbol.getName(), DOUBLE); FunctionHandle functionHandle = FUNCTION_MANAGER.lookupFunction("sum", fromTypes(DOUBLE)); WindowNode.Frame frame = new WindowNode.Frame( @@ -171,7 +182,7 @@ public void testValidWindow() newId(), baseTableScan, specification, - ImmutableMap.of(windowSymbol, function), + ImmutableMap.of(windowVariable, function), Optional.empty(), ImmutableSet.of(), 0); @@ -182,19 +193,19 @@ public void testValidWindow() @Test public void testValidAggregation() { - Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", DOUBLE); + VariableReferenceExpression aggregationVariable = symbolAllocator.newVariable("sum", DOUBLE); PlanNode node = new AggregationNode( newId(), baseTableScan, - ImmutableMap.of(aggregationSymbol, new Aggregation( + ImmutableMap.of(aggregationVariable, new Aggregation( FUNCTION_MANAGER.lookupFunction("sum", fromTypes(DOUBLE)), ImmutableList.of(columnC.toSymbolReference()), Optional.empty(), Optional.empty(), false, Optional.empty())), - singleGroupingSet(ImmutableList.of(columnA, columnB)), + singleGroupingSet(ImmutableList.of(variableA, variableB)), ImmutableList.of(), SINGLE, Optional.empty(), @@ -208,22 +219,22 @@ public void testValidTypeOnlyCoercion() { Expression expression = new Cast(columnB.toSymbolReference(), StandardTypes.BIGINT); Assignments assignments = Assignments.builder() - .put(symbolAllocator.newSymbol(expression, BIGINT), expression) - .put(symbolAllocator.newSymbol(columnE.toSymbolReference(), VARCHAR), columnE.toSymbolReference()) // implicit coercion from varchar(3) to varchar + .put(symbolAllocator.newVariable(expression, BIGINT), expression) + .put(symbolAllocator.newVariable(columnE.toSymbolReference(), VARCHAR), columnE.toSymbolReference()) // implicit coercion from varchar(3) to varchar .build(); PlanNode node = new ProjectNode(newId(), baseTableScan, assignments); assertTypesValid(node); } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'expr(_[0-9]+)?' is expected to be bigint, but the actual type is integer") + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of variable 'expr(_[0-9]+)?' is expected to be bigint, but the actual type is integer") public void testInvalidProject() { Expression expression1 = new Cast(columnB.toSymbolReference(), StandardTypes.INTEGER); Expression expression2 = new Cast(columnA.toSymbolReference(), StandardTypes.INTEGER); Assignments assignments = Assignments.builder() - .put(symbolAllocator.newSymbol(expression1, BIGINT), expression1) // should be INTEGER - .put(symbolAllocator.newSymbol(expression1, INTEGER), expression2) + .put(symbolAllocator.newVariable(expression1, BIGINT), expression1) // should be INTEGER + .put(symbolAllocator.newVariable(expression1, INTEGER), expression2) .build(); PlanNode node = new ProjectNode( newId(), @@ -234,22 +245,22 @@ public void testInvalidProject() } // This test will be disable temporarily until we converted Aggregation to use CallExpression - @Test(enabled = false, expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint") + @Test(enabled = false, expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of variable 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint") public void testInvalidAggregationFunctionCall() { - Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", DOUBLE); + VariableReferenceExpression aggregationVariable = symbolAllocator.newVariable("sum", DOUBLE); PlanNode node = new AggregationNode( newId(), baseTableScan, - ImmutableMap.of(aggregationSymbol, new Aggregation( + ImmutableMap.of(aggregationVariable, new Aggregation( FUNCTION_MANAGER.lookupFunction("sum", fromTypes(DOUBLE)), ImmutableList.of(columnA.toSymbolReference()), Optional.empty(), Optional.empty(), false, Optional.empty())), - singleGroupingSet(ImmutableList.of(columnA, columnB)), + singleGroupingSet(ImmutableList.of(variableA, variableB)), ImmutableList.of(), SINGLE, Optional.empty(), @@ -258,22 +269,22 @@ public void testInvalidAggregationFunctionCall() assertTypesValid(node); } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint") + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of variable 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint") public void testInvalidAggregationFunctionSignature() { - Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", DOUBLE); + VariableReferenceExpression aggregationVariable = symbolAllocator.newVariable("sum", DOUBLE); PlanNode node = new AggregationNode( newId(), baseTableScan, - ImmutableMap.of(aggregationSymbol, new Aggregation( + ImmutableMap.of(aggregationVariable, new Aggregation( FUNCTION_MANAGER.lookupFunction("sum", fromTypes(BIGINT)), // should be DOUBLE ImmutableList.of(columnC.toSymbolReference()), Optional.empty(), Optional.empty(), false, Optional.empty())), - singleGroupingSet(ImmutableList.of(columnA, columnB)), + singleGroupingSet(ImmutableList.of(variableA, variableB)), ImmutableList.of(), SINGLE, Optional.empty(), @@ -282,10 +293,11 @@ public void testInvalidAggregationFunctionSignature() assertTypesValid(node); } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint") + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of variable 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint") public void testInvalidWindowFunctionCall() { Symbol windowSymbol = symbolAllocator.newSymbol("sum", DOUBLE); + VariableReferenceExpression windowVariable = new VariableReferenceExpression(windowSymbol.getName(), DOUBLE); FunctionHandle functionHandle = FUNCTION_MANAGER.lookupFunction("sum", fromTypes(DOUBLE)); WindowNode.Frame frame = new WindowNode.Frame( @@ -305,7 +317,7 @@ public void testInvalidWindowFunctionCall() newId(), baseTableScan, specification, - ImmutableMap.of(windowSymbol, function), + ImmutableMap.of(windowVariable, function), Optional.empty(), ImmutableSet.of(), 0); @@ -313,10 +325,11 @@ public void testInvalidWindowFunctionCall() assertTypesValid(node); } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint") + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of variable 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint") public void testInvalidWindowFunctionSignature() { Symbol windowSymbol = symbolAllocator.newSymbol("sum", DOUBLE); + VariableReferenceExpression windowVariable = new VariableReferenceExpression(windowSymbol.getName(), DOUBLE); FunctionHandle functionHandle = FUNCTION_MANAGER.lookupFunction("sum", fromTypes(BIGINT)); // should be DOUBLE WindowNode.Frame frame = new WindowNode.Frame( @@ -336,7 +349,7 @@ public void testInvalidWindowFunctionSignature() newId(), baseTableScan, specification, - ImmutableMap.of(windowSymbol, function), + ImmutableMap.of(windowVariable, function), Optional.empty(), ImmutableSet.of(), 0); @@ -344,20 +357,19 @@ public void testInvalidWindowFunctionSignature() assertTypesValid(node); } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'output(_[0-9]+)?' is expected to be date, but the actual type is bigint") + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of variable 'output(_[0-9]+)?' is expected to be date, but the actual type is bigint") public void testInvalidUnion() { - Symbol outputSymbol = symbolAllocator.newSymbol("output", DATE); - ListMultimap mappings = ImmutableListMultimap.builder() - .put(outputSymbol, columnD) - .put(outputSymbol, columnA) // should be a symbol with DATE type + VariableReferenceExpression output = symbolAllocator.newVariable("output", DATE); + ListMultimap mappings = ImmutableListMultimap.builder() + .put(output, variableD) + .put(output, variableA) // should be a symbol with DATE type .build(); PlanNode node = new UnionNode( newId(), ImmutableList.of(baseTableScan, baseTableScan), - mappings, - ImmutableList.copyOf(mappings.keySet())); + mappings); assertTypesValid(node); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java index ad62cfb5d92f8..1776d3c21c44c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java @@ -16,13 +16,14 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.OrderingScheme; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.OrderBy; +import com.facebook.presto.sql.tree.SymbolReference; import java.util.Map; import java.util.Optional; @@ -42,9 +43,9 @@ public AggregationFunctionMatcher(ExpectedValueProvider callMaker) } @Override - public Optional getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public Optional getAssignedVariable(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) { - Optional result = Optional.empty(); + Optional result = Optional.empty(); if (!(node instanceof AggregationNode)) { return result; } @@ -52,7 +53,7 @@ public Optional getAssignedSymbol(PlanNode node, Session session, Metada AggregationNode aggregationNode = (AggregationNode) node; FunctionCall expectedCall = callMaker.getExpectedValue(symbolAliases); - for (Map.Entry assignment : aggregationNode.getAggregations().entrySet()) { + for (Map.Entry assignment : aggregationNode.getAggregations().entrySet()) { if (verifyAggregation(metadata.getFunctionManager(), assignment.getValue(), expectedCall)) { checkState(!result.isPresent(), "Ambiguous function calls in %s", aggregationNode); result = Optional.of(assignment.getKey()); @@ -85,9 +86,9 @@ private static boolean verifyAggregationOrderBy(OrderingScheme orderingScheme, O return false; } for (int i = 0; i < expectedSortOrder.getSortItems().size(); i++) { - Symbol orderingSymbol = orderingScheme.getOrderBy().get(i); - if (expectedSortOrder.getSortItems().get(i).getSortKey().equals(orderingSymbol.toSymbolReference()) && - toSortOrder(expectedSortOrder.getSortItems().get(i)).equals(orderingScheme.getOrdering(orderingSymbol))) { + VariableReferenceExpression orderingVariable = orderingScheme.getOrderBy().get(i); + if (expectedSortOrder.getSortItems().get(i).getSortKey().equals(new SymbolReference(orderingVariable.getName())) && + toSortOrder(expectedSortOrder.getSortItems().get(i)).equals(orderingScheme.getOrdering(orderingVariable))) { continue; } return false; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java index a3bc881984b9b..afb12a0c8c68d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Step; @@ -25,6 +26,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; @@ -32,6 +34,7 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; public class AggregationMatcher implements Matcher @@ -63,7 +66,7 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); AggregationNode aggregationNode = (AggregationNode) node; - if (groupId.isPresent() != aggregationNode.getGroupIdSymbol().isPresent()) { + if (groupId.isPresent() != aggregationNode.getGroupIdVariable().isPresent()) { return NO_MATCH; } @@ -79,19 +82,19 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses return NO_MATCH; } - List aggregationsWithMask = aggregationNode.getAggregations() + List aggregationsWithMask = aggregationNode.getAggregations() .entrySet() .stream() .filter(entry -> entry.getValue().isDistinct()) - .map(entry -> entry.getKey()) + .map(Map.Entry::getKey) .collect(Collectors.toList()); if (aggregationsWithMask.size() != masks.keySet().size()) { return NO_MATCH; } - for (Symbol symbol : aggregationsWithMask) { - if (!masks.keySet().contains(symbol)) { + for (VariableReferenceExpression variable : aggregationsWithMask) { + if (!masks.keySet().contains(new Symbol(variable.getName()))) { return NO_MATCH; } } @@ -100,25 +103,26 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses return NO_MATCH; } - if (!matches(preGroupedSymbols, aggregationNode.getPreGroupedSymbols(), symbolAliases)) { + if (!matches(preGroupedSymbols, aggregationNode.getPreGroupedVariables(), symbolAliases)) { return NO_MATCH; } return match(); } - static boolean matches(Collection expectedAliases, Collection actualSymbols, SymbolAliases symbolAliases) + static boolean matches(Collection expectedAliases, Collection actualVariables, SymbolAliases symbolAliases) { - if (expectedAliases.size() != actualSymbols.size()) { + if (expectedAliases.size() != actualVariables.size()) { return false; } - List expectedSymbols = expectedAliases + List expectedSymbolNames = expectedAliases .stream() - .map(alias -> new Symbol(symbolAliases.get(alias).getName())) + .map(alias -> symbolAliases.get(alias).getName()) .collect(toImmutableList()); - for (Symbol symbol : expectedSymbols) { - if (!actualSymbols.contains(symbol)) { + Set actualVariableNames = actualVariables.stream().map(VariableReferenceExpression::getName).collect(toImmutableSet()); + for (String symbolName : expectedSymbolNames) { + if (!actualVariableNames.contains(symbolName)) { return false; } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasMatcher.java index f9be1080607e1..7f37aee9207cb 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasMatcher.java @@ -16,8 +16,9 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.tree.SymbolReference; import java.util.Optional; @@ -54,11 +55,11 @@ public boolean shapeMatches(PlanNode node) @Override public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { - Optional symbol = matcher.getAssignedSymbol(node, session, metadata, symbolAliases); - if (symbol.isPresent() && alias.isPresent()) { - return match(alias.get(), symbol.get().toSymbolReference()); + Optional variable = matcher.getAssignedVariable(node, session, metadata, symbolAliases); + if (variable.isPresent() && alias.isPresent()) { + return match(alias.get(), new SymbolReference(variable.get().getName())); } - return new MatchResult(symbol.isPresent()); + return new MatchResult(variable.isPresent()); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasPresent.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasPresent.java index 40a9da90be0a7..cc1317c7ee9c1 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasPresent.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasPresent.java @@ -15,11 +15,12 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.PlanNode; import java.util.Optional; +import static com.facebook.presto.type.UnknownType.UNKNOWN; import static java.util.Objects.requireNonNull; /** @@ -36,10 +37,10 @@ class AliasPresent } @Override - public Optional getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public Optional getAssignedVariable(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) { return symbolAliases.getOptional(alias) - .map(Symbol::from); + .map(alias -> new VariableReferenceExpression(alias.getName(), UNKNOWN)); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AssignUniqueIdMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AssignUniqueIdMatcher.java index eafcbe89034f3..81098e9f312d3 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AssignUniqueIdMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AssignUniqueIdMatcher.java @@ -15,7 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -27,7 +27,7 @@ public class AssignUniqueIdMatcher implements RvalueMatcher { @Override - public Optional getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public Optional getAssignedVariable(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) { if (!(node instanceof AssignUniqueId)) { return Optional.empty(); @@ -35,7 +35,7 @@ public Optional getAssignedSymbol(PlanNode node, Session session, Metada AssignUniqueId assignUniqueIdNode = (AssignUniqueId) node; - return Optional.of(assignUniqueIdNode.getIdColumn()); + return Optional.of(assignUniqueIdNode.getIdVariable()); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BaseStrictSymbolsMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BaseStrictSymbolsMatcher.java index ae4eaafda8477..af9bf6b9aee17 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BaseStrictSymbolsMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BaseStrictSymbolsMatcher.java @@ -16,21 +16,23 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.PlanNode; import java.util.Set; import java.util.function.Function; +import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Objects.requireNonNull; public abstract class BaseStrictSymbolsMatcher implements Matcher { - private final Function> getActual; + private final Function> getActual; - public BaseStrictSymbolsMatcher(Function> getActual) + public BaseStrictSymbolsMatcher(Function> getActual) { this.getActual = requireNonNull(getActual, "getActual is null"); } @@ -51,8 +53,16 @@ public boolean shapeMatches(PlanNode node) public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); - return new MatchResult(getActual.apply(node).equals(getExpectedSymbols(node, session, metadata, symbolAliases))); + return new MatchResult(match(getActual.apply(node), getExpectedVariables(node, session, metadata, symbolAliases))); } - protected abstract Set getExpectedSymbols(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases); + protected abstract Set getExpectedVariables(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases); + + boolean match(Set actual, Set expected) + { + if (expected.stream().anyMatch(variable -> variable.getType().equals(UNKNOWN))) { + return actual.stream().map(VariableReferenceExpression::getName).collect(toImmutableSet()).equals(expected.stream().map(VariableReferenceExpression::getName).collect(toImmutableSet())); + } + return actual.equals(expected); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ColumnReference.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ColumnReference.java index a9c96b88b319a..1fe8a6bfcb10f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ColumnReference.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ColumnReference.java @@ -18,7 +18,7 @@ import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.metadata.TableMetadata; import com.facebook.presto.spi.ColumnHandle; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.TableScanNode; @@ -43,10 +43,10 @@ public ColumnReference(String tableName, String columnName) } @Override - public Optional getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public Optional getAssignedVariable(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) { TableHandle tableHandle; - Map assignments; + Map assignments; if (node instanceof TableScanNode) { TableScanNode tableScanNode = (TableScanNode) node; @@ -74,13 +74,13 @@ else if (node instanceof IndexSourceNode) { checkState(columnHandle.isPresent(), format("Table %s doesn't have column %s. Typo in test?", tableName, columnName)); - return getAssignedSymbol(assignments, columnHandle.get()); + return getAssignedVariable(assignments, columnHandle.get()); } - private Optional getAssignedSymbol(Map assignments, ColumnHandle columnHandle) + private Optional getAssignedVariable(Map assignments, ColumnHandle columnHandle) { - Optional result = Optional.empty(); - for (Map.Entry entry : assignments.entrySet()) { + Optional result = Optional.empty(); + for (Map.Entry entry : assignments.entrySet()) { if (entry.getValue().equals(columnHandle)) { checkState(!result.isPresent(), "Multiple ColumnHandles found for %s:%s in table scan assignments", tableName, columnName); result = Optional.of(entry.getKey()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CorrelationMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CorrelationMatcher.java index be20971039915..d5dd3a95e60ce 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CorrelationMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CorrelationMatcher.java @@ -16,10 +16,11 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.tree.SymbolReference; import java.util.List; @@ -53,21 +54,21 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); - List actualCorrelation = getCorrelation(node); + List actualCorrelation = getCorrelation(node); if (this.correlation.size() != actualCorrelation.size()) { return NO_MATCH; } int i = 0; for (String alias : this.correlation) { - if (!symbolAliases.get(alias).equals(actualCorrelation.get(i++).toSymbolReference())) { + if (!symbolAliases.get(alias).equals(new SymbolReference(actualCorrelation.get(i++).getName()))) { return NO_MATCH; } } return match(); } - private List getCorrelation(PlanNode node) + private List getCorrelation(PlanNode node) { if (node instanceof ApplyNode) { return ((ApplyNode) node).getCorrelation(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/EquiJoinClauseProvider.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/EquiJoinClauseProvider.java index 23cc0aa5e5918..a3e13f31dc3f5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/EquiJoinClauseProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/EquiJoinClauseProvider.java @@ -13,8 +13,10 @@ */ package com.facebook.presto.sql.planner.assertions; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.JoinNode; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static java.util.Objects.requireNonNull; class EquiJoinClauseProvider @@ -31,7 +33,9 @@ class EquiJoinClauseProvider public JoinNode.EquiJoinClause getExpectedValue(SymbolAliases aliases) { - return new JoinNode.EquiJoinClause(left.toSymbol(aliases), right.toSymbol(aliases)); + return new JoinNode.EquiJoinClause( + new VariableReferenceExpression(left.toSymbol(aliases).getName(), BIGINT), + new VariableReferenceExpression(right.toSymbol(aliases).getName(), BIGINT)); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java index 83b363b1a303d..c722924be6257 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java @@ -15,8 +15,8 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -51,11 +51,11 @@ private Expression expression(String sql) } @Override - public Optional getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public Optional getAssignedVariable(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) { - Optional result = Optional.empty(); + Optional result = Optional.empty(); ImmutableList.Builder matchesBuilder = ImmutableList.builder(); - Map assignments = getAssignments(node); + Map assignments = getAssignments(node); if (assignments == null) { return result; @@ -63,7 +63,7 @@ public Optional getAssignedSymbol(PlanNode node, Session session, Metada ExpressionVerifier verifier = new ExpressionVerifier(symbolAliases); - for (Map.Entry assignment : assignments.entrySet()) { + for (Map.Entry assignment : assignments.entrySet()) { if (verifier.process(assignment.getValue(), expression)) { result = Optional.of(assignment.getKey()); matchesBuilder.add(assignment.getValue()); @@ -76,7 +76,7 @@ public Optional getAssignedSymbol(PlanNode node, Session session, Metada return result; } - private static Map getAssignments(PlanNode node) + private static Map getAssignments(PlanNode node) { if (node instanceof ProjectNode) { ProjectNode projectNode = (ProjectNode) node; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/GroupIdMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/GroupIdMatcher.java index 861d84c36fb78..9976876b70e16 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/GroupIdMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/GroupIdMatcher.java @@ -16,9 +16,10 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.GroupIdNode; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.tree.SymbolReference; import java.util.List; import java.util.Map; @@ -54,8 +55,8 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); GroupIdNode groudIdNode = (GroupIdNode) node; - List> actualGroups = groudIdNode.getGroupingSets(); - List actualAggregationArguments = groudIdNode.getAggregationArguments(); + List> actualGroups = groudIdNode.getGroupingSets(); + List actualAggregationArguments = groudIdNode.getAggregationArguments(); if (actualGroups.size() != groups.size()) { return NO_MATCH; @@ -71,7 +72,7 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses return NO_MATCH; } - return match(groupIdAlias, groudIdNode.getGroupIdSymbol().toSymbolReference()); + return match(groupIdAlias, new SymbolReference(groudIdNode.getGroupIdVariable().getName())); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java index 3a3f408e710c3..6199c91e08018 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java @@ -21,7 +21,7 @@ import com.facebook.presto.sql.planner.plan.JoinNode.DistributionType; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.tree.Expression; -import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Optional; @@ -103,10 +103,11 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses * Have to use order-independent comparison; there are no guarantees what order * the equi criteria will have after planning and optimizing. */ - Set actual = ImmutableSet.copyOf(joinNode.getCriteria()); - Set expected = + Set> actual = joinNode.getCriteria().stream().map(criteria -> ImmutableList.of(criteria.getLeft().getName(), criteria.getRight().getName())).collect(toImmutableSet()); + Set> expected = equiCriteria.stream() .map(maker -> maker.getExpectedValue(symbolAliases)) + .map(criteria -> ImmutableList.of(criteria.getLeft().getName(), criteria.getRight().getName())) .collect(toImmutableSet()); return new MatchResult(expected.equals(actual)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/MarkDistinctMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/MarkDistinctMatcher.java index 893866ce845ba..e9c53b33d3e66 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/MarkDistinctMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/MarkDistinctMatcher.java @@ -16,10 +16,12 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import java.util.List; import java.util.Optional; @@ -42,7 +44,7 @@ public MarkDistinctMatcher(PlanTestSymbol markerSymbol, List dis { this.markerSymbol = requireNonNull(markerSymbol, "markerSymbol is null"); this.distinctSymbols = ImmutableList.copyOf(distinctSymbols); - this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + this.hashSymbol = requireNonNull(hashSymbol, "hashVariable is null"); } @Override @@ -57,16 +59,20 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); MarkDistinctNode markDistinctNode = (MarkDistinctNode) node; - if (!markDistinctNode.getHashSymbol().equals(hashSymbol.map(alias -> alias.toSymbol(symbolAliases)))) { + if (!markDistinctNode.getHashVariable().map(variable -> new Symbol(variable.getName())).equals(hashSymbol.map(alias -> alias.toSymbol(symbolAliases)))) { return NO_MATCH; } - if (!ImmutableSet.copyOf(markDistinctNode.getDistinctSymbols()) + if (!markDistinctNode.getDistinctVariables() + .stream() + .map(VariableReferenceExpression::getName) + .map(Symbol::new) + .collect(toImmutableSet()) .equals(distinctSymbols.stream().map(alias -> alias.toSymbol(symbolAliases)).collect(toImmutableSet()))) { return NO_MATCH; } - return match(markerSymbol.toString(), markDistinctNode.getMarkerSymbol().toSymbolReference()); + return match(markerSymbol.toString(), new SymbolReference(markDistinctNode.getMarkerVariable().getName())); } @Override @@ -75,7 +81,7 @@ public String toString() return toStringHelper(this) .add("markerSymbol", markerSymbol) .add("distinctSymbols", distinctSymbols) - .add("hashSymbol", hashSymbol) + .add("hashVariable", hashSymbol) .toString(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java index a86290e66c505..c36d129d3ac73 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java @@ -89,14 +89,14 @@ public void matches(PlanMatchPattern pattern) { PlanNode actual = optimizer.optimize(plan, session, types, new SymbolAllocator(), idAllocator, WarningCollector.NOOP); - if (!ImmutableSet.copyOf(plan.getOutputSymbols()).equals(ImmutableSet.copyOf(actual.getOutputSymbols()))) { + if (!ImmutableSet.copyOf(plan.getOutputVariables()).equals(ImmutableSet.copyOf(actual.getOutputVariables()))) { fail(String.format( "%s: output schema of transformed and original plans are not equivalent\n" + "\texpected: %s\n" + "\tactual: %s", optimizer.getClass().getName(), - plan.getOutputSymbols(), - actual.getOutputSymbols())); + plan.getOutputVariables(), + actual.getOutputVariables())); } inTransaction(session -> { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OutputMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OutputMatcher.java index b599b6ef55c61..bdbe188206d8e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OutputMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OutputMatcher.java @@ -16,9 +16,10 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import java.util.List; @@ -51,9 +52,9 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses for (String alias : aliases) { Expression expression = symbolAliases.get(alias); boolean found = false; - while (i < node.getOutputSymbols().size()) { - Symbol outputSymbol = node.getOutputSymbols().get(i++); - if (expression.equals(outputSymbol.toSymbolReference())) { + while (i < node.getOutputVariables().size()) { + VariableReferenceExpression outputVariable = node.getOutputVariables().get(i++); + if (expression.equals(new SymbolReference(outputVariable.getName()))) { found = true; break; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index bb962b5ab9a53..d671a0036791b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -639,11 +639,6 @@ public PlanMatchPattern withExactAssignedOutputs(Collection expectedAliases) { matchers.add(new StrictAssignedSymbolsMatcher(actualAssignments(), expectedAliases)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java index b0e619507c9a8..255515e63bb79 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java @@ -16,7 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.Assignments; @@ -24,6 +24,7 @@ import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.tree.SymbolReference; import java.util.List; @@ -52,8 +53,8 @@ final class PlanMatchingVisitor @Override public MatchResult visitExchange(ExchangeNode node, PlanMatchPattern pattern) { - List> allInputs = node.getInputs(); - List outputs = node.getOutputSymbols(); + List> allInputs = node.getInputs(); + List outputs = node.getOutputVariables(); MatchResult result = super.visitExchange(node, pattern); @@ -62,10 +63,10 @@ public MatchResult visitExchange(ExchangeNode node, PlanMatchPattern pattern) } SymbolAliases newAliases = result.getAliases(); - for (List inputs : allInputs) { + for (List inputs : allInputs) { Assignments.Builder assignments = Assignments.builder(); for (int i = 0; i < inputs.size(); ++i) { - assignments.put(outputs.get(i), inputs.get(i).toSymbolReference()); + assignments.put(outputs.get(i), new SymbolReference(inputs.get(i).getName())); } newAliases = newAliases.updateAssignments(assignments.build()); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowNumberMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowNumberMatcher.java index 69315a349932b..2bf3cfe5518b0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowNumberMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowNumberMatcher.java @@ -16,6 +16,8 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; @@ -47,7 +49,7 @@ private RowNumberMatcher( this.partitionBy = requireNonNull(partitionBy, "partitionBy is null"); this.maxRowCountPerPartition = requireNonNull(maxRowCountPerPartition, "maxRowCountPerPartition is null"); this.rowNumberSymbol = requireNonNull(rowNumberSymbol, "rowNumberSymbol is null"); - this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + this.hashSymbol = requireNonNull(hashSymbol, "hashVariable is null"); } @Override @@ -66,8 +68,9 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses if (!partitionBy .map(expectedPartitionBy -> expectedPartitionBy.stream() .map(alias -> alias.toSymbol(symbolAliases)) + .map(Symbol::getName) .collect(toImmutableList()) - .equals(rowNumberNode.getPartitionBy())) + .equals(rowNumberNode.getPartitionBy().stream().map(VariableReferenceExpression::getName).collect(toImmutableList()))) .orElse(true)) { return NO_MATCH; } @@ -80,8 +83,8 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses if (!rowNumberSymbol .map(expectedRowNumberSymbol -> - expectedRowNumberSymbol.toSymbol(symbolAliases) - .equals(rowNumberNode.getRowNumberSymbol())) + expectedRowNumberSymbol.toSymbol(symbolAliases).getName() + .equals(rowNumberNode.getRowNumberVariable().getName())) .orElse(true)) { return NO_MATCH; } @@ -90,7 +93,8 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses .map(expectedHashSymbol -> expectedHashSymbol .map(symbolAlias -> symbolAlias.toSymbol(symbolAliases)) - .equals(rowNumberNode.getHashSymbol())) + .map(Symbol::getName) + .equals(rowNumberNode.getHashVariable().map(VariableReferenceExpression::getName))) .orElse(true)) { return NO_MATCH; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RvalueMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RvalueMatcher.java index 69118f7c521e7..cf5ffd1dd73db 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RvalueMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RvalueMatcher.java @@ -15,7 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.PlanNode; import java.util.Optional; @@ -32,5 +32,5 @@ public interface RvalueMatcher * The assigned symbol is identified by matching the value on the right side of the assignment; * the rvalue. If no match is found in the node, getAssignedSymbol must return Optional.empty(). */ - Optional getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases); + Optional getAssignedVariable(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SemiJoinMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SemiJoinMatcher.java index 8f31c66540ce9..2791991882db6 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SemiJoinMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SemiJoinMatcher.java @@ -18,6 +18,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; +import com.facebook.presto.sql.tree.SymbolReference; import java.util.Optional; @@ -55,8 +56,8 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); SemiJoinNode semiJoinNode = (SemiJoinNode) node; - if (!(symbolAliases.get(sourceSymbolAlias).equals(semiJoinNode.getSourceJoinSymbol().toSymbolReference()) && - symbolAliases.get(filteringSymbolAlias).equals(semiJoinNode.getFilteringSourceJoinSymbol().toSymbolReference()))) { + if (!(symbolAliases.get(sourceSymbolAlias).equals(new SymbolReference(semiJoinNode.getSourceJoinVariable().getName())) && + symbolAliases.get(filteringSymbolAlias).equals(new SymbolReference(semiJoinNode.getFilteringSourceJoinVariable().getName())))) { return NO_MATCH; } @@ -64,7 +65,7 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses return NO_MATCH; } - return match(outputAlias, semiJoinNode.getSemiJoinOutput().toSymbolReference()); + return match(outputAlias, new SymbolReference(semiJoinNode.getSemiJoinOutput().getName())); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java index aa3c3eb02b44f..300be38e191ec 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.spi.block.SortOrder; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.OrderingScheme; import com.facebook.presto.sql.planner.plan.WindowNode; import com.google.common.collect.ImmutableList; @@ -23,9 +24,11 @@ import java.util.Map; import java.util.Optional; +import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Objects.requireNonNull; public class SpecificationProvider @@ -53,18 +56,18 @@ public WindowNode.Specification getExpectedValue(SymbolAliases aliases) orderingScheme = Optional.of(new OrderingScheme( orderBy .stream() - .map(alias -> alias.toSymbol(aliases)) + .map(alias -> new VariableReferenceExpression(alias.toSymbol(aliases).getName(), UNKNOWN)) .collect(toImmutableList()), orderings .entrySet() .stream() - .collect(toImmutableMap(entry -> entry.getKey().toSymbol(aliases), Map.Entry::getValue)))); + .collect(toImmutableMap(entry -> new VariableReferenceExpression(entry.getKey().toSymbol(aliases).getName(), UNKNOWN), Map.Entry::getValue)))); } return new WindowNode.Specification( partitionBy .stream() - .map(alias -> alias.toSymbol(aliases)) + .map(alias -> new VariableReferenceExpression(alias.toSymbol(aliases).getName(), UNKNOWN)) .collect(toImmutableList()), orderingScheme); } @@ -78,4 +81,26 @@ public String toString() .add("orderings", this.orderings) .toString(); } + + /* + * Since plan matching is done through SymbolAlias, which does not include type information, we cannot directly use + * VariableReferenceExpression::equals to check whether two specification are equivalent once they include VariableReferenceExpression. + * TODO Directly use equals once SymbolAlias is converted to something with type information. + */ + public static boolean matchSpecification(WindowNode.Specification actual, WindowNode.Specification expected) + { + return actual.getPartitionBy().stream().map(VariableReferenceExpression::getName).collect(toImmutableList()) + .equals(expected.getPartitionBy().stream().map(VariableReferenceExpression::getName).collect(toImmutableList())) && + actual.getOrderingScheme().map(orderingScheme -> orderingScheme.getOrderBy().stream() + .map(VariableReferenceExpression::getName) + .collect(toImmutableSet()) + .equals(expected.getOrderingScheme().get().getOrderBy().stream() + .map(VariableReferenceExpression::getName) + .collect(toImmutableSet())) && + orderingScheme.getOrderings().entrySet().stream() + .collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue)) + .equals(expected.getOrderingScheme().get().getOrderings().entrySet().stream() + .collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue)))) + .orElse(true); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/StrictAssignedSymbolsMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/StrictAssignedSymbolsMatcher.java index 23093dccf2e31..9b86f8b7cf3a7 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/StrictAssignedSymbolsMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/StrictAssignedSymbolsMatcher.java @@ -15,8 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.google.common.collect.ImmutableSet; @@ -34,18 +33,18 @@ public class StrictAssignedSymbolsMatcher { private final Collection getExpected; - public StrictAssignedSymbolsMatcher(Function> getActual, Collection getExpected) + public StrictAssignedSymbolsMatcher(Function> getActual, Collection getExpected) { super(getActual); this.getExpected = requireNonNull(getExpected, "getExpected is null"); } @Override - protected Set getExpectedSymbols(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + protected Set getExpectedVariables(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) { - ImmutableSet.Builder expected = ImmutableSet.builder(); + ImmutableSet.Builder expected = ImmutableSet.builder(); for (RvalueMatcher matcher : getExpected) { - Optional assigned = matcher.getAssignedSymbol(node, session, metadata, symbolAliases); + Optional assigned = matcher.getAssignedVariable(node, session, metadata, symbolAliases); if (!assigned.isPresent()) { return null; } @@ -56,14 +55,9 @@ protected Set getExpectedSymbols(PlanNode node, Session session, Metadat return expected.build(); } - public static Function> actualAssignments() + public static Function> actualAssignments() { - return node -> ((ProjectNode) node).getAssignments().getSymbols(); - } - - public static Function> actualSubqueryAssignments() - { - return node -> ((ApplyNode) node).getSubqueryAssignments().getSymbols(); + return node -> ((ProjectNode) node).getAssignments().getVariables(); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/StrictSymbolsMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/StrictSymbolsMatcher.java index caa97ed311ec9..5faf3b06adcca 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/StrictSymbolsMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/StrictSymbolsMatcher.java @@ -15,7 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.PlanNode; import com.google.common.collect.ImmutableSet; @@ -23,6 +23,7 @@ import java.util.Set; import java.util.function.Function; +import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Objects.requireNonNull; @@ -32,24 +33,24 @@ public class StrictSymbolsMatcher { private final List expectedAliases; - public StrictSymbolsMatcher(Function> getActual, List expectedAliases) + public StrictSymbolsMatcher(Function> getActual, List expectedAliases) { super(getActual); this.expectedAliases = requireNonNull(expectedAliases, "expectedAliases is null"); } @Override - protected Set getExpectedSymbols(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + protected Set getExpectedVariables(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) { return expectedAliases.stream() .map(symbolAliases::get) - .map(Symbol::from) + .map(symbolReference -> new VariableReferenceExpression(symbolReference.getName(), UNKNOWN)) .collect(toImmutableSet()); } - public static Function> actualOutputs() + public static Function> actualOutputs() { - return node -> ImmutableSet.copyOf(node.getOutputSymbols()); + return node -> ImmutableSet.copyOf(node.getOutputVariables()); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java index 1ec113d359b06..f3463ece9cca3 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.assertions; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SymbolReference; @@ -97,13 +97,13 @@ private static String toKey(String alias) private Map getUpdatedAssignments(Assignments assignments) { ImmutableMap.Builder mapUpdate = ImmutableMap.builder(); - for (Map.Entry assignment : assignments.getMap().entrySet()) { + for (Map.Entry assignment : assignments.getMap().entrySet()) { for (Map.Entry existingAlias : map.entrySet()) { if (assignment.getValue().equals(existingAlias.getValue())) { // Simple symbol rename - mapUpdate.put(existingAlias.getKey(), assignment.getKey().toSymbolReference()); + mapUpdate.put(existingAlias.getKey(), new SymbolReference(assignment.getKey().getName())); } - else if (assignment.getKey().toSymbolReference().equals(existingAlias.getValue())) { + else if (new SymbolReference(assignment.getKey().getName()).equals(existingAlias.getValue())) { /* * Special case for nodes that can alias symbols in the node's assignment map. * In this case, we've already added the alias in the map, but we won't include it diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolCardinalityMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolCardinalityMatcher.java index f3da059774f3c..4376bd5bc988a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolCardinalityMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolCardinalityMatcher.java @@ -39,7 +39,7 @@ public boolean shapeMatches(PlanNode node) @Override public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { - return new MatchResult(node.getOutputSymbols().size() == numberOfSymbols); + return new MatchResult(node.getOutputVariables().size() == numberOfSymbols); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableWriterMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableWriterMatcher.java index 9f676f6ce5395..0555b2d24c761 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableWriterMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableWriterMatcher.java @@ -17,6 +17,7 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.TableWriterNode; @@ -60,7 +61,7 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses if (!columns.stream() .map(s -> Symbol.from(symbolAliases.get(s))) .collect(toImmutableList()) - .equals(tableWriterNode.getColumns())) { + .equals(tableWriterNode.getColumns().stream().map(VariableReferenceExpression::getName).map(Symbol::new).collect(toImmutableList()))) { return NO_MATCH; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TopNRowNumberMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TopNRowNumberMatcher.java index fcd3aefb14c94..1eb26d7142a73 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TopNRowNumberMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TopNRowNumberMatcher.java @@ -17,6 +17,8 @@ import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.block.SortOrder; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.WindowNode; @@ -28,6 +30,7 @@ import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; import static com.facebook.presto.sql.planner.assertions.MatchResult.match; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.SpecificationProvider.matchSpecification; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; @@ -52,7 +55,7 @@ private TopNRowNumberMatcher( this.rowNumberSymbol = requireNonNull(rowNumberSymbol, "rowNumberSymbol is null"); this.maxRowCountPerPartition = requireNonNull(maxRowCountPerPartition, "maxRowCountPerPartition is null"); this.partial = requireNonNull(partial, "partial is null"); - this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + this.hashSymbol = requireNonNull(hashSymbol, "hashVariable is null"); } @Override @@ -69,17 +72,15 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses TopNRowNumberNode topNRowNumberNode = (TopNRowNumberNode) node; if (!specification - .map(expectedSpecification -> - expectedSpecification.getExpectedValue(symbolAliases) - .equals(topNRowNumberNode.getSpecification())) + .map(expectedSpecification -> matchSpecification(topNRowNumberNode.getSpecification(), expectedSpecification.getExpectedValue(symbolAliases))) .orElse(true)) { return NO_MATCH; } if (!rowNumberSymbol .map(expectedRowNumberSymbol -> - expectedRowNumberSymbol.toSymbol(symbolAliases) - .equals(topNRowNumberNode.getRowNumberSymbol())) + expectedRowNumberSymbol.toSymbol(symbolAliases).getName() + .equals(topNRowNumberNode.getRowNumberVariable().getName())) .orElse(true)) { return NO_MATCH; } @@ -100,7 +101,8 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses .map(expectedHashSymbol -> expectedHashSymbol .map(symbolAlias -> symbolAlias.toSymbol(symbolAliases)) - .equals(topNRowNumberNode.getHashSymbol())) + .map(Symbol::getName) + .equals(topNRowNumberNode.getHashVariable().map(VariableReferenceExpression::getName))) .orElse(true)) { return NO_MATCH; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Util.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Util.java index 211d44ddf6beb..d7a988b57aa59 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Util.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Util.java @@ -77,7 +77,7 @@ static boolean orderingSchemeMatches(List expectedOrderBy, OrderingSch for (int i = 0; i < expectedOrderBy.size(); ++i) { Ordering ordering = expectedOrderBy.get(i); Symbol symbol = Symbol.from(symbolAliases.get(ordering.getField())); - if (!symbol.equals(orderingScheme.getOrderBy().get(i))) { + if (!symbol.equals(new Symbol(orderingScheme.getOrderBy().get(i).getName()))) { return false; } if (!ordering.getSortOrder().equals(orderingScheme.getOrdering(symbol))) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ValuesMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ValuesMatcher.java index ab65370f13f4e..a7cde9cc83fae 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ValuesMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ValuesMatcher.java @@ -25,6 +25,7 @@ import com.facebook.presto.sql.tree.GenericLiteral; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.StringLiteral; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import io.airlift.slice.Slice; @@ -63,7 +64,7 @@ public ValuesMatcher( public boolean shapeMatches(PlanNode node) { return (node instanceof ValuesNode) && - expectedOutputSymbolCount.map(Integer.valueOf(node.getOutputSymbols().size())::equals).orElse(true); + expectedOutputSymbolCount.map(Integer.valueOf(node.getOutputVariables().size())::equals).orElse(true); } @Override @@ -101,7 +102,7 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses } return match(SymbolAliases.builder() - .putAll(Maps.transformValues(outputSymbolAliases, index -> valuesNode.getOutputSymbols().get(index).toSymbolReference())) + .putAll(Maps.transformValues(outputSymbolAliases, index -> new SymbolReference(valuesNode.getOutputVariables().get(index).getName()))) .build()); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFrameProvider.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFrameProvider.java index 439327c6bd0c3..05b95d761a5cb 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFrameProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFrameProvider.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.assertions; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType; import com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType; @@ -20,6 +21,7 @@ import java.util.Optional; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; @@ -57,9 +59,9 @@ public WindowNode.Frame getExpectedValue(SymbolAliases aliases) return new WindowNode.Frame( type, startType, - startValue.map(alias -> alias.toSymbol(aliases)), + startValue.map(alias -> new VariableReferenceExpression(alias.toSymbol(aliases).getName(), BIGINT)), endType, - endValue.map(alias -> alias.toSymbol(aliases)), + endValue.map(alias -> new VariableReferenceExpression(alias.toSymbol(aliases).getName(), BIGINT)), originalStartValue.map(Expression::toString), originalEndValue.map(Expression::toString)); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java index 71fc78f14b271..36d0dd1ac6c42 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java @@ -17,7 +17,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.relation.RowExpression; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.WindowNode; @@ -59,9 +59,9 @@ public WindowFunctionMatcher( } @Override - public Optional getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public Optional getAssignedVariable(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) { - Optional result = Optional.empty(); + Optional result = Optional.empty(); if (!(node instanceof WindowNode)) { return result; } @@ -71,7 +71,7 @@ public Optional getAssignedSymbol(PlanNode node, Session session, Metada FunctionCall expectedCall = callMaker.getExpectedValue(symbolAliases); Optional expectedFrame = frameMaker.map(maker -> maker.getExpectedValue(symbolAliases)); - List matchedOutputs = windowNode.getWindowFunctions().entrySet().stream() + List matchedOutputs = windowNode.getWindowFunctions().entrySet().stream() .filter(assignment -> { if (!expectedCall.getName().equals(QualifiedName.of(metadata.getFunctionManager().getFunctionMetadata(assignment.getValue().getFunctionCall().getFunctionHandle()).getName()))) { return false; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java index 65e80689f6945..ec6c4c081b8e8 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java @@ -18,6 +18,8 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.FunctionCall; @@ -31,6 +33,7 @@ import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; import static com.facebook.presto.sql.planner.assertions.MatchResult.match; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.SpecificationProvider.matchSpecification; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -76,15 +79,13 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses .map(expectedInputs -> expectedInputs.stream() .map(alias -> alias.toSymbol(symbolAliases)) .collect(toImmutableSet()) - .equals(windowNode.getPrePartitionedInputs())) + .equals(windowNode.getPrePartitionedInputs().stream().map(VariableReferenceExpression::getName).map(Symbol::new).collect(toImmutableSet()))) .orElse(true)) { return NO_MATCH; } if (!specification - .map(expectedSpecification -> - expectedSpecification.getExpectedValue(symbolAliases) - .equals(windowNode.getSpecification())) + .map(expectedSpecification -> matchSpecification(windowNode.getSpecification(), expectedSpecification.getExpectedValue(symbolAliases))) .orElse(true)) { return NO_MATCH; } @@ -98,7 +99,8 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses if (!hashSymbol .map(expectedHashSymbol -> expectedHashSymbol .map(alias -> alias.toSymbol(symbolAliases)) - .equals(windowNode.getHashSymbol())) + .map(Symbol::getName) + .equals(windowNode.getHashVariable().map(VariableReferenceExpression::getName))) .orElse(true)) { return NO_MATCH; } @@ -204,16 +206,7 @@ public Builder addFunction( } /** - * Matches only if WindowNode.getHashSymbol() is an empty option. - */ - public Builder hashSymbol() - { - this.hashSymbol = Optional.of(Optional.empty()); - return this; - } - - /** - * Matches only if WindowNode.getHashSymbol() is a non-empty option containing hashSymbol. + * Matches only if WindowNode.getHashVariable() is a non-empty option containing hashVariable. */ public Builder hashSymbol(String hashSymbol) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java index cd16340416e61..dc203ba9e466c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java @@ -107,13 +107,13 @@ public Result apply(ProjectNode project, Captures captures, Context context) if (isIdentityProjection(project)) { return Result.ofPlanNode(project.getSource()); } - PlanNode projectNode = new ProjectNode(context.getIdAllocator().getNextId(), project, Assignments.identity(project.getOutputSymbols())); + PlanNode projectNode = new ProjectNode(context.getIdAllocator().getNextId(), project, Assignments.identity(project.getOutputVariables())); return Result.ofPlanNode(projectNode); } private static boolean isIdentityProjection(ProjectNode project) { - return ImmutableSet.copyOf(project.getOutputSymbols()).equals(ImmutableSet.copyOf(project.getSource().getOutputSymbols())); + return ImmutableSet.copyOf(project.getOutputVariables()).equals(ImmutableSet.copyOf(project.getSource().getOutputVariables())); } } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestMemo.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestMemo.java index 942c11d29d0ed..bdf4cacb5ed00 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestMemo.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestMemo.java @@ -17,7 +17,7 @@ import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.PlanNode; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; @@ -300,7 +300,7 @@ public List getSources() } @Override - public List getOutputSymbols() + public List getOutputVariables() { return ImmutableList.of(); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java index 50a3d14687e34..2f8cde762a7f9 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java @@ -58,15 +58,15 @@ public void testBasic() .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("c")), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE_STREAMING, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("b")), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a")))))); + p.values(p.variable("a")))))); })) .matches( aggregation( @@ -113,15 +113,15 @@ public void testNoInputCount() .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("c")), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE_STREAMING, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b"), expression("count(*)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("b")), expression("count(*)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a")))))); + p.values(p.variable("a")))))); })) .matches( aggregation( @@ -166,7 +166,7 @@ public void testMultipleExchanges() .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("c")), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE_STREAMING, @@ -174,9 +174,9 @@ public void testMultipleExchanges() ExchangeNode.Scope.REMOTE_STREAMING, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("b")), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a"))))))); + p.values(p.variable("a"))))))); })) .matches( aggregation( @@ -220,15 +220,15 @@ public void testSessionDisable() .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("c")), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE_STREAMING, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("b")), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a")))))); + p.values(p.variable("a")))))); })) .doesNotFire(); } @@ -244,15 +244,15 @@ public void testNoLocalParallel() .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("c")), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE_STREAMING, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("b")), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a")))))); + p.values(p.variable("a")))))); })) .matches( aggregation( @@ -285,17 +285,17 @@ public void testWithGroups() .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true") .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { - af.singleGroupingSet(p.symbol("c")) + af.singleGroupingSet(p.variable("c")) .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("c")), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE_STREAMING, - p.aggregation(ap -> ap.singleGroupingSet(p.symbol("b")) + p.aggregation(ap -> ap.singleGroupingSet(p.variable("b")) .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("b")), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a")))))); + p.values(p.variable("a")))))); })) .doesNotFire(); } @@ -311,17 +311,17 @@ public void testInterimProject() .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("c")), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE_STREAMING, p.project( - Assignments.identity(p.symbol("b")), + Assignments.identity(p.variable("b")), p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("b")), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a"))))))); + p.values(p.variable("a"))))))); })) .matches( aggregation( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java index 155128415bb63..1481775245915 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java @@ -15,11 +15,11 @@ import com.facebook.presto.cost.CostComparator; import com.facebook.presto.cost.PlanNodeStatsEstimate; -import com.facebook.presto.cost.SymbolStatsEstimate; import com.facebook.presto.cost.TaskCountEstimator; +import com.facebook.presto.cost.VariableStatsEstimate; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.RuleAssert; import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -93,13 +93,13 @@ private void testDetermineDistributionType(JoinDistributionType sessionDistribut p.join( joinType, p.values( - ImmutableList.of(p.symbol("A1")), + ImmutableList.of(p.variable(p.symbol("A1"))), ImmutableList.of(constantExpressions(BIGINT, 10), constantExpressions(BIGINT, 11))), p.values( - ImmutableList.of(p.symbol("B1")), + ImmutableList.of(p.variable(p.symbol("B1"))), ImmutableList.of(constantExpressions(BIGINT, 50), constantExpressions(BIGINT, 11))), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable(p.symbol("A1", BIGINT)), p.variable(p.symbol("B1", BIGINT)))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .setSystemProperty(JOIN_DISTRIBUTION_TYPE, sessionDistributedJoin.name()) .matches(join( @@ -129,13 +129,13 @@ private void testRepartitionRightOuter(JoinDistributionType sessionDistributedJo p.join( joinType, p.values( - ImmutableList.of(p.symbol("A1")), + ImmutableList.of(p.variable(p.symbol("A1"))), ImmutableList.of(constantExpressions(BIGINT, 10), constantExpressions(BIGINT, 11))), p.values( - ImmutableList.of(p.symbol("B1")), + ImmutableList.of(p.variable(p.symbol("B1"))), ImmutableList.of(constantExpressions(BIGINT, 50), constantExpressions(BIGINT, 11))), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable(p.symbol("A1", BIGINT)), p.variable(p.symbol("B1", BIGINT)))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .setSystemProperty(JOIN_DISTRIBUTION_TYPE, sessionDistributedJoin.name()) .matches(join( @@ -155,14 +155,14 @@ public void testReplicateScalar() p.join( INNER, p.values( - ImmutableList.of(p.symbol("A1")), + ImmutableList.of(p.variable(p.symbol("A1"))), ImmutableList.of(constantExpressions(BIGINT, 10), constantExpressions(BIGINT, 11))), p.enforceSingleRow( p.values( - ImmutableList.of(p.symbol("B1")), + ImmutableList.of(p.variable(p.symbol("B1"))), ImmutableList.of(constantExpressions(BIGINT, 50), constantExpressions(BIGINT, 11)))), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable(p.symbol("A1", BIGINT)), p.variable(p.symbol("B1", BIGINT)))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.PARTITIONED.name()) .matches(join( @@ -188,13 +188,13 @@ private void testReplicateNoEquiCriteria(Type joinType) p.join( joinType, p.values( - ImmutableList.of(p.symbol("A1")), + ImmutableList.of(p.variable(p.symbol("A1"))), ImmutableList.of(constantExpressions(BIGINT, 10), constantExpressions(BIGINT, 11))), p.values( - ImmutableList.of(p.symbol("B1")), + ImmutableList.of(p.variable(p.symbol("B1"))), ImmutableList.of(constantExpressions(BIGINT, 50), constantExpressions(BIGINT, 11))), ImmutableList.of(), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.of(expression("A1 * B1 > 100")))) .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.PARTITIONED.name()) .matches(join( @@ -214,13 +214,13 @@ public void testRetainDistributionType() p.join( INNER, p.values( - ImmutableList.of(p.symbol("A1")), + ImmutableList.of(p.variable(p.symbol("A1"))), ImmutableList.of(constantExpressions(BIGINT, 10), constantExpressions(BIGINT, 11))), p.values( - ImmutableList.of(p.symbol("B1")), + ImmutableList.of(p.variable(p.symbol("B1"))), ImmutableList.of(constantExpressions(BIGINT, 50), constantExpressions(BIGINT, 11))), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable(p.symbol("A1", BIGINT)), p.variable(p.symbol("B1", BIGINT)))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty(), Optional.empty(), Optional.empty(), @@ -237,19 +237,19 @@ public void testFlipAndReplicateWhenOneTableMuchSmaller() .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 6400, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .matches(join( INNER, @@ -270,20 +270,20 @@ public void testFlipAndReplicateWhenOneTableMuchSmallerAndJoinCardinalityUnknown .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) // set symbol stats to unknown, so the join cardinality cannot be estimated - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), SymbolStatsEstimate.unknown())) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), VariableStatsEstimate.unknown())) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) // set symbol stats to unknown, so the join cardinality cannot be estimated - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), SymbolStatsEstimate.unknown())) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), VariableStatsEstimate.unknown())) .build()) .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .matches(join( INNER, @@ -303,19 +303,19 @@ public void testPartitionWhenRequiredBySession() .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 6400, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.PARTITIONED.name()) .matches(join( @@ -336,19 +336,19 @@ public void testPartitionWhenBothTablesEqual() .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .matches(join( INNER, @@ -368,19 +368,19 @@ public void testReplicatesWhenRequiredBySession() .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.BROADCAST.name()) .matches(join( @@ -401,19 +401,19 @@ public void testPartitionFullOuterJoin() .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .on(p -> p.join( FULL, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .matches(join( FULL, @@ -433,19 +433,19 @@ public void testPartitionRightOuterJoin() .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .on(p -> p.join( RIGHT, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .matches(join( RIGHT, @@ -465,19 +465,19 @@ public void testReplicateLeftOuterJoin() .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .on(p -> p.join( LEFT, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .matches(join( LEFT, @@ -497,19 +497,19 @@ public void testFlipAndReplicateRightOuterJoin() .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .on(p -> p.join( RIGHT, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .matches(join( LEFT, @@ -530,20 +530,20 @@ public void testFlipAndReplicateRightOuterJoinWhenJoinCardinalityUnknown() .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) // set symbol stats to unknown, so the join cardinality cannot be estimated - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), SymbolStatsEstimate.unknown())) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), VariableStatsEstimate.unknown())) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) // set symbol stats to unknown, so the join cardinality cannot be estimated - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), SymbolStatsEstimate.unknown())) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), VariableStatsEstimate.unknown())) .build()) .on(p -> p.join( RIGHT, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .matches(join( LEFT, @@ -562,11 +562,11 @@ public void testReplicatesWhenNotRestricted() PlanNodeStatsEstimate probeSideStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 10))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 10))) .build(); PlanNodeStatsEstimate buildSideStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 10))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 10))) .build(); // B table is small enough to be replicated in AUTOMATIC_RESTRICTED mode @@ -578,10 +578,10 @@ public void testReplicatesWhenNotRestricted() .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .matches(join( INNER, @@ -593,11 +593,11 @@ public void testReplicatesWhenNotRestricted() probeSideStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000d * 10000, 10))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) .build(); buildSideStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000d * 10000, 10))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) .build(); // B table exceeds AUTOMATIC_RESTRICTED limit therefore it is partitioned @@ -609,10 +609,10 @@ public void testReplicatesWhenNotRestricted() .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("A1", BIGINT), p.variable("B1", BIGINT))), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .matches(join( INNER, @@ -632,19 +632,19 @@ public void testChoosesLeftWhenCriteriaEmpty() .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .on(p -> p.join( RIGHT, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), ImmutableList.of(), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .matches(join( LEFT, @@ -665,19 +665,19 @@ public void testChoosesRightWhenFallsBackToSyntactic() .setSystemProperty(JOIN_MAX_BROADCAST_TABLE_SIZE, "100MB") .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .on(p -> p.join( RIGHT, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), ImmutableList.of(), - ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + ImmutableList.of(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), Optional.empty())) .matches(join( RIGHT, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineSemiJoinDistributionType.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineSemiJoinDistributionType.java index 1641480acd8ca..c9f4592824a41 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineSemiJoinDistributionType.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineSemiJoinDistributionType.java @@ -29,11 +29,11 @@ import com.facebook.presto.cost.CostComparator; import com.facebook.presto.cost.PlanNodeStatsEstimate; -import com.facebook.presto.cost.SymbolStatsEstimate; import com.facebook.presto.cost.TaskCountEstimator; +import com.facebook.presto.cost.VariableStatsEstimate; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.RuleAssert; import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; import com.google.common.collect.ImmutableList; @@ -81,14 +81,14 @@ public void testRetainDistributionType() .on(p -> p.semiJoin( p.values( - ImmutableList.of(p.symbol("A1")), + ImmutableList.of(p.variable(p.symbol("A1"))), ImmutableList.of(constantExpressions(BIGINT, 10), constantExpressions(BIGINT, 11))), p.values( - ImmutableList.of(p.symbol("B1")), + ImmutableList.of(p.variable(p.symbol("B1"))), ImmutableList.of(constantExpressions(BIGINT, 50), constantExpressions(BIGINT, 11))), - p.symbol("A1"), - p.symbol("B1"), - p.symbol("output"), + p.variable("A1"), + p.variable("B1"), + p.variable("output"), Optional.empty(), Optional.empty(), Optional.of(REPLICATED))) @@ -104,19 +104,19 @@ public void testPartitionWhenRequiredBySession() .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.PARTITIONED.name()) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 6400, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .on(p -> p.semiJoin( - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - p.symbol("A1"), - p.symbol("B1"), - p.symbol("output"), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + p.variable("A1"), + p.variable("B1"), + p.variable("output"), Optional.empty(), Optional.empty(), Optional.empty())) @@ -139,19 +139,19 @@ public void testReplicatesWhenRequiredBySession() .setSystemProperty(JOIN_MAX_BROADCAST_TABLE_SIZE, "1B") .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), SymbolStatsEstimate.unknown())) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), VariableStatsEstimate.unknown())) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), SymbolStatsEstimate.unknown())) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), VariableStatsEstimate.unknown())) .build()) .on(p -> p.semiJoin( - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - p.symbol("A1"), - p.symbol("B1"), - p.symbol("output"), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + p.variable("A1"), + p.variable("B1"), + p.variable("output"), Optional.empty(), Optional.empty(), Optional.empty())) @@ -173,19 +173,19 @@ public void testPartitionsWhenBothTablesEqual() .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), SymbolStatsEstimate.unknown())) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), VariableStatsEstimate.unknown())) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), SymbolStatsEstimate.unknown())) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), VariableStatsEstimate.unknown())) .build()) .on(p -> p.semiJoin( - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - p.symbol("A1"), - p.symbol("B1"), - p.symbol("output"), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + p.variable("A1"), + p.variable("B1"), + p.variable("output"), Optional.empty(), Optional.empty(), Optional.empty())) @@ -207,19 +207,19 @@ public void testReplicatesWhenFilterMuchSmaller() .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), SymbolStatsEstimate.unknown())) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), VariableStatsEstimate.unknown())) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), SymbolStatsEstimate.unknown())) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), VariableStatsEstimate.unknown())) .build()) .on(p -> p.semiJoin( - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - p.symbol("A1"), - p.symbol("B1"), - p.symbol("output"), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + p.variable("A1"), + p.variable("B1"), + p.variable("output"), Optional.empty(), Optional.empty(), Optional.empty())) @@ -240,11 +240,11 @@ public void testReplicatesWhenNotRestricted() PlanNodeStatsEstimate probeSideStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 10))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 10))) .build(); PlanNodeStatsEstimate buildSideStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 10))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 10))) .build(); // B table is small enough to be replicated in AUTOMATIC_RESTRICTED mode @@ -255,11 +255,11 @@ public void testReplicatesWhenNotRestricted() .overrideStats("valuesB", buildSideStatsEstimate) .on(p -> p.semiJoin( - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - p.symbol("A1"), - p.symbol("B1"), - p.symbol("output"), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + p.variable("A1"), + p.variable("B1"), + p.variable("output"), Optional.empty(), Optional.empty(), Optional.empty())) @@ -273,11 +273,11 @@ public void testReplicatesWhenNotRestricted() probeSideStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000d * 10000, 10))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) .build(); buildSideStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000d * 10000, 10))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) .build(); // B table exceeds AUTOMATIC_RESTRICTED limit therefore it is partitioned @@ -288,11 +288,11 @@ public void testReplicatesWhenNotRestricted() .overrideStats("valuesB", buildSideStatsEstimate) .on(p -> p.semiJoin( - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", BIGINT)), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", BIGINT)), - p.symbol("A1"), - p.symbol("B1"), - p.symbol("output"), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", BIGINT)), + p.variable("A1"), + p.variable("B1"), + p.variable("output"), Optional.empty(), Optional.empty(), Optional.empty())) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java index 262074f4a0bea..f3c790a1bdba0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; @@ -36,6 +36,7 @@ import java.util.function.Function; import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; @@ -44,7 +45,6 @@ import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.tree.ArithmeticUnaryExpression.Sign.MINUS; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -64,9 +64,9 @@ public void testEliminateCrossJoin() .on(crossJoinAndJoin(INNER)) .matches( join(INNER, - ImmutableList.of(aliases -> new EquiJoinClause(new Symbol("cySymbol"), new Symbol("bySymbol"))), + ImmutableList.of(aliases -> new EquiJoinClause(variable("cyVariable"), variable("byVariable"))), join(INNER, - ImmutableList.of(aliases -> new EquiJoinClause(new Symbol("axSymbol"), new Symbol("cxSymbol"))), + ImmutableList.of(aliases -> new EquiJoinClause(variable("axVariable"), variable("cxVariable"))), any(), any()), any())); @@ -108,11 +108,11 @@ public void testJoinOrder() PlanNode plan = joinNode( joinNode( - values(symbol("a")), - values(symbol("b"))), - values(symbol("c")), - symbol("a"), symbol("c"), - symbol("c"), symbol("b")); + values(variable("a")), + values(variable("b"))), + values(variable("c")), + variable("a"), variable("c"), + variable("c"), variable("b")); JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); @@ -127,20 +127,20 @@ public void testJoinOrderWithRealCrossJoin() PlanNode leftPlan = joinNode( joinNode( - values(symbol("a")), - values(symbol("b"))), - values(symbol("c")), - symbol("a"), symbol("c"), - symbol("c"), symbol("b")); + values(variable("a")), + values(variable("b"))), + values(variable("c")), + variable("a"), variable("c"), + variable("c"), variable("b")); PlanNode rightPlan = joinNode( joinNode( - values(symbol("x")), - values(symbol("y"))), - values(symbol("z")), - symbol("x"), symbol("z"), - symbol("z"), symbol("y")); + values(variable("x")), + values(variable("y"))), + values(variable("z")), + variable("x"), variable("z"), + variable("z"), variable("y")); PlanNode plan = joinNode(leftPlan, rightPlan); @@ -157,12 +157,12 @@ public void testJoinOrderWithMultipleEdgesBetweenNodes() PlanNode plan = joinNode( joinNode( - values(symbol("a")), - values(symbol("b1"), symbol("b2"))), - values(symbol("c1"), symbol("c2")), - symbol("a"), symbol("c1"), - symbol("c1"), symbol("b1"), - symbol("c2"), symbol("b2")); + values(variable("a")), + values(variable("b1"), variable("b2"))), + values(variable("c1"), variable("c2")), + variable("a"), variable("c1"), + variable("c1"), variable("b1"), + variable("c2"), variable("b2")); JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); @@ -177,11 +177,11 @@ public void testDonNotChangeOrderWithoutCrossJoin() PlanNode plan = joinNode( joinNode( - values(symbol("a")), - values(symbol("b")), - symbol("a"), symbol("b")), - values(symbol("c")), - symbol("c"), symbol("b")); + values(variable("a")), + values(variable("b")), + variable("a"), variable("b")), + values(variable("c")), + variable("c"), variable("b")); JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); @@ -196,10 +196,10 @@ public void testDoNotReorderCrossJoins() PlanNode plan = joinNode( joinNode( - values(symbol("a")), - values(symbol("b"))), - values(symbol("c")), - symbol("c"), symbol("b")); + values(variable("a")), + values(variable("b"))), + values(variable("c")), + variable("c"), variable("b")); JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); @@ -215,13 +215,13 @@ public void testGiveUpOnNonIdentityProjections() joinNode( projectNode( joinNode( - values(symbol("a1")), - values(symbol("b"))), - symbol("a2"), + values(variable("a1")), + values(variable("b"))), + variable("a2"), new ArithmeticUnaryExpression(MINUS, new SymbolReference("a1"))), - values(symbol("c")), - symbol("a2"), symbol("c"), - symbol("c"), symbol("b")); + values(variable("c")), + variable("a2"), variable("c"), + variable("c"), variable("b")); assertEquals(JoinGraph.buildFrom(plan).size(), 2); } @@ -229,42 +229,42 @@ public void testGiveUpOnNonIdentityProjections() private Function crossJoinAndJoin(JoinNode.Type secondJoinType) { return p -> { - Symbol axSymbol = p.symbol("axSymbol"); - Symbol bySymbol = p.symbol("bySymbol"); - Symbol cxSymbol = p.symbol("cxSymbol"); - Symbol cySymbol = p.symbol("cySymbol"); + VariableReferenceExpression axVariable = p.variable("axVariable"); + VariableReferenceExpression byVariable = p.variable("byVariable"); + VariableReferenceExpression cxVariable = p.variable("cxVariable"); + VariableReferenceExpression cyVariable = p.variable("cyVariable"); // (a inner join b) inner join c on c.x = a.x and c.y = b.y return p.join(INNER, p.join(secondJoinType, - p.values(axSymbol), - p.values(bySymbol)), - p.values(cxSymbol, cySymbol), - new EquiJoinClause(cxSymbol, axSymbol), - new EquiJoinClause(cySymbol, bySymbol)); + p.values(axVariable), + p.values(byVariable)), + p.values(cxVariable, cyVariable), + new EquiJoinClause(p.variable(cxVariable), p.variable(axVariable)), + new EquiJoinClause(p.variable(cyVariable), p.variable(byVariable))); }; } - private PlanNode projectNode(PlanNode source, String symbol, Expression expression) + private PlanNode projectNode(PlanNode source, VariableReferenceExpression variable, Expression expression) { return new ProjectNode( idAllocator.getNextId(), source, - Assignments.of(new Symbol(symbol), expression)); + Assignments.of(variable, expression)); } - private String symbol(String name) + private VariableReferenceExpression variable(String name) { - return name; + return new VariableReferenceExpression(name, BIGINT); } - private JoinNode joinNode(PlanNode left, PlanNode right, String... symbols) + private JoinNode joinNode(PlanNode left, PlanNode right, VariableReferenceExpression... variables) { - checkArgument(symbols.length % 2 == 0); + checkArgument(variables.length % 2 == 0); ImmutableList.Builder criteria = ImmutableList.builder(); - for (int i = 0; i < symbols.length; i += 2) { - criteria.add(new JoinNode.EquiJoinClause(new Symbol(symbols[i]), new Symbol(symbols[i + 1]))); + for (int i = 0; i < variables.length; i += 2) { + criteria.add(new JoinNode.EquiJoinClause(variables[i], variables[i + 1])); } return new JoinNode( @@ -273,9 +273,9 @@ private JoinNode joinNode(PlanNode left, PlanNode right, String... symbols) left, right, criteria.build(), - ImmutableList.builder() - .addAll(left.getOutputSymbols()) - .addAll(right.getOutputSymbols()) + ImmutableList.builder() + .addAll(left.getOutputVariables()) + .addAll(right.getOutputVariables()) .build(), Optional.empty(), Optional.empty(), @@ -283,11 +283,11 @@ private JoinNode joinNode(PlanNode left, PlanNode right, String... symbols) Optional.empty()); } - private ValuesNode values(String... symbols) + private ValuesNode values(VariableReferenceExpression... variables) { return new ValuesNode( idAllocator.getNextId(), - Arrays.stream(symbols).map(Symbol::new).collect(toImmutableList()), + Arrays.asList(variables), ImmutableList.of()); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroLimit.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroLimit.java index 6ef69ff517986..dc53bf797db2f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroLimit.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroLimit.java @@ -33,7 +33,7 @@ public void testDoesNotFire() .on(p -> p.limit( 1, - p.values(p.symbol("a")))) + p.values(p.variable("a")))) .doesNotFire(); } @@ -47,7 +47,7 @@ public void test() p.filter( expression("b > 5"), p.values( - ImmutableList.of(p.symbol("a"), p.symbol("b")), + ImmutableList.of(p.variable(p.symbol("a")), p.variable(p.symbol("b"))), ImmutableList.of( constantExpressions(BIGINT, 1, 10), constantExpressions(BIGINT, 2, 11)))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroSample.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroSample.java index 20e1f8603378a..054784c90e0ef 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroSample.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroSample.java @@ -35,7 +35,7 @@ public void testDoesNotFire() p.sample( 0.15, Type.BERNOULLI, - p.values(p.symbol("a")))) + p.values(p.variable("a")))) .doesNotFire(); } @@ -50,7 +50,7 @@ public void test() p.filter( expression("b > 5"), p.values( - ImmutableList.of(p.symbol("a"), p.symbol("b")), + ImmutableList.of(p.variable("a"), p.variable("b")), ImmutableList.of( constantExpressions(BIGINT, 1, 10), constantExpressions(BIGINT, 2, 11)))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java index 7a06477b969c1..08c0bc72463ac 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.type.DateType; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; @@ -59,8 +58,8 @@ public void testProjectionExpressionRewrite() { tester().assertThat(zeroRewriter.projectExpressionRewrite()) .on(p -> p.project( - Assignments.of(p.symbol("y"), PlanBuilder.expression("x IS NOT NULL")), - p.values(p.symbol("x")))) + Assignments.of(p.variable("y"), PlanBuilder.expression("x IS NOT NULL")), + p.values(p.variable("x")))) .matches( project(ImmutableMap.of("y", expression("0")), values("x"))); } @@ -70,8 +69,8 @@ public void testProjectionExpressionNotRewritten() { tester().assertThat(zeroRewriter.projectExpressionRewrite()) .on(p -> p.project( - Assignments.of(p.symbol("y"), PlanBuilder.expression("0")), - p.values(p.symbol("x")))) + Assignments.of(p.variable("y"), PlanBuilder.expression("0")), + p.values(p.variable("x")))) .doesNotFire(); } @@ -82,11 +81,11 @@ public void testAggregationExpressionRewrite() .on(p -> p.aggregation(a -> a .globalGrouping() .addAggregation( - p.symbol("count_1", BIGINT), + p.variable(p.symbol("count_1", BIGINT)), new FunctionCall(QualifiedName.of("count"), ImmutableList.of(p.symbol("y", BIGINT).toSymbolReference())), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("x", BIGINT))))) + p.values(p.variable("x", BIGINT))))) .matches( PlanMatchPattern.aggregation( ImmutableMap.of("count_1", functionCall("count", ImmutableList.of("x"))), @@ -101,7 +100,7 @@ public void testAggregationExpressionNotRewritten() .on(p -> p.aggregation(a -> a .globalGrouping() .addAggregation( - p.symbol("count_1", DateType.DATE), + p.variable(p.symbol("count_1", DateType.DATE)), nowCall, ImmutableList.of()) .source( @@ -112,7 +111,7 @@ public void testAggregationExpressionNotRewritten() .on(p -> p.aggregation(a -> a .globalGrouping() .addAggregation( - p.symbol("count_1", BIGINT), + p.variable(p.symbol("count_1", BIGINT)), new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), ImmutableList.of(BIGINT)) .source( @@ -142,7 +141,7 @@ public void testValueExpressionRewrite() { tester().assertThat(zeroRewriter.valuesExpressionRewrite()) .on(p -> p.values( - ImmutableList.of(p.symbol("a")), + ImmutableList.of(p.variable(p.symbol("a"))), ImmutableList.of((ImmutableList.of(castToRowExpression(PlanBuilder.expression("1"))))))) .matches( values(ImmutableList.of("a"), ImmutableList.of(ImmutableList.of(new LongLiteral("0"))))); @@ -153,7 +152,7 @@ public void testValueExpressionNotRewritten() { tester().assertThat(zeroRewriter.valuesExpressionRewrite()) .on(p -> p.values( - ImmutableList.of(p.symbol("a")), + ImmutableList.of(p.variable(p.symbol("a"))), ImmutableList.of((ImmutableList.of(castToRowExpression(PlanBuilder.expression("0"))))))) .doesNotFire(); } @@ -164,7 +163,7 @@ public void testApplyExpressionRewrite() tester().assertThat(applyRewriter.applyExpressionRewrite()) .on(p -> p.apply( Assignments.of( - p.symbol("a", BIGINT), + p.variable("a", BIGINT), new InPredicate( new LongLiteral("1"), new InListExpression(ImmutableList.of( @@ -187,7 +186,7 @@ public void testApplyExpressionNotRewritten() tester().assertThat(applyRewriter.applyExpressionRewrite()) .on(p -> p.apply( Assignments.of( - p.symbol("a", BIGINT), + p.variable("a", BIGINT), new InPredicate( new LongLiteral("0"), new InListExpression(ImmutableList.of( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java index efd8fcf6c5b20..10cf25e2c433b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java @@ -34,21 +34,21 @@ public void test() .on(p -> p.project( Assignments.builder() - .put(p.symbol("identity"), expression("symbol")) // identity - .put(p.symbol("multi_complex_1"), expression("complex + 1")) // complex expression referenced multiple times - .put(p.symbol("multi_complex_2"), expression("complex + 2")) // complex expression referenced multiple times - .put(p.symbol("multi_literal_1"), expression("literal + 1")) // literal referenced multiple times - .put(p.symbol("multi_literal_2"), expression("literal + 2")) // literal referenced multiple times - .put(p.symbol("single_complex"), expression("complex_2 + 2")) // complex expression reference only once - .put(p.symbol("try"), expression("try(complex / literal)")) + .put(p.variable("identity"), expression("symbol")) // identity + .put(p.variable("multi_complex_1"), expression("complex + 1")) // complex expression referenced multiple times + .put(p.variable("multi_complex_2"), expression("complex + 2")) // complex expression referenced multiple times + .put(p.variable("multi_literal_1"), expression("literal + 1")) // literal referenced multiple times + .put(p.variable("multi_literal_2"), expression("literal + 2")) // literal referenced multiple times + .put(p.variable("single_complex"), expression("complex_2 + 2")) // complex expression reference only once + .put(p.variable("try"), expression("try(complex / literal)")) .build(), p.project(Assignments.builder() - .put(p.symbol("symbol"), expression("x")) - .put(p.symbol("complex"), expression("x * 2")) - .put(p.symbol("literal"), expression("1")) - .put(p.symbol("complex_2"), expression("x - 1")) + .put(p.variable("symbol"), expression("x")) + .put(p.variable("complex"), expression("x * 2")) + .put(p.variable("literal"), expression("1")) + .put(p.variable("complex_2"), expression("x - 1")) .build(), - p.values(p.symbol("x"))))) + p.values(p.variable("x"))))) .matches( project( ImmutableMap.builder() @@ -73,10 +73,10 @@ public void testIdentityProjections() tester().assertThat(new InlineProjections()) .on(p -> p.project( - Assignments.of(p.symbol("output"), expression("value")), + Assignments.of(p.variable("output"), expression("value")), p.project( - Assignments.identity(p.symbol("value")), - p.values(p.symbol("value"))))) + Assignments.identity(p.variable("value")), + p.values(p.variable("value"))))) .doesNotFire(); } @@ -86,10 +86,10 @@ public void testSubqueryProjections() tester().assertThat(new InlineProjections()) .on(p -> p.project( - Assignments.identity(p.symbol("fromOuterScope"), p.symbol("value")), + Assignments.identity(p.variable("fromOuterScope"), p.variable("value")), p.project( - Assignments.identity(p.symbol("value")), - p.values(p.symbol("value"))))) + Assignments.identity(p.variable("value")), + p.values(p.variable("value"))))) .doesNotFire(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java index 2549bd96a63b7..d39b6865d17e4 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java @@ -23,7 +23,7 @@ import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Rule; @@ -91,8 +91,8 @@ public void testDoesNotCreateJoinWhenPartitionedOnCrossJoin() { PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); PlanBuilder p = new PlanBuilder(idAllocator, queryRunner.getMetadata()); - Symbol a1 = p.symbol("A1"); - Symbol b1 = p.symbol("B1"); + VariableReferenceExpression a1 = p.variable("A1"); + VariableReferenceExpression b1 = p.variable("B1"); MultiJoinNode multiJoinNode = new MultiJoinNode( new LinkedHashSet<>(ImmutableList.of(p.values(a1), p.values(b1))), TRUE_LITERAL, @@ -101,7 +101,7 @@ public void testDoesNotCreateJoinWhenPartitionedOnCrossJoin() new CostComparator(1, 1, 1), multiJoinNode.getFilter(), createContext()); - JoinEnumerationResult actual = joinEnumerator.createJoinAccordingToPartitioning(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols(), ImmutableSet.of(0)); + JoinEnumerationResult actual = joinEnumerator.createJoinAccordingToPartitioning(multiJoinNode.getSources(), multiJoinNode.getOutputVariables(), ImmutableSet.of(0)); assertFalse(actual.getPlanNode().isPresent()); assertEquals(actual.getCost(), PlanCostEstimate.infinite()); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java index 30e01ea6cdf9f..5e72f535ce014 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java @@ -15,7 +15,7 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.MultiJoinNode; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -25,6 +25,7 @@ import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.testing.LocalQueryRunner; import com.google.common.collect.ImmutableList; import org.testng.annotations.AfterClass; @@ -72,8 +73,8 @@ public void tearDown() public void testDoesNotAllowOuterJoin() { PlanBuilder p = planBuilder(); - Symbol a1 = p.symbol("A1"); - Symbol b1 = p.symbol("B1"); + VariableReferenceExpression a1 = p.variable("A1"); + VariableReferenceExpression b1 = p.variable("B1"); JoinNode outerJoin = p.join( FULL, p.values(a1), @@ -88,9 +89,9 @@ public void testDoesNotAllowOuterJoin() public void testDoesNotConvertNestedOuterJoins() { PlanBuilder p = planBuilder(); - Symbol a1 = p.symbol("A1"); - Symbol b1 = p.symbol("B1"); - Symbol c1 = p.symbol("C1"); + VariableReferenceExpression a1 = p.variable("A1"); + VariableReferenceExpression b1 = p.variable("B1"); + VariableReferenceExpression c1 = p.variable("C1"); JoinNode leftJoin = p.join( LEFT, p.values(a1), @@ -109,7 +110,7 @@ public void testDoesNotConvertNestedOuterJoins() MultiJoinNode expected = MultiJoinNode.builder() .setSources(leftJoin, valuesC).setFilter(createEqualsExpression(a1, c1)) - .setOutputSymbols(a1, b1, c1) + .setOutputVariables(a1, b1, c1) .build(); assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT), expected); } @@ -118,11 +119,11 @@ public void testDoesNotConvertNestedOuterJoins() public void testRetainsOutputSymbols() { PlanBuilder p = planBuilder(); - Symbol a1 = p.symbol("A1"); - Symbol b1 = p.symbol("B1"); - Symbol b2 = p.symbol("B2"); - Symbol c1 = p.symbol("C1"); - Symbol c2 = p.symbol("C2"); + VariableReferenceExpression a1 = p.variable("A1"); + VariableReferenceExpression b1 = p.variable("B1"); + VariableReferenceExpression b2 = p.variable("B2"); + VariableReferenceExpression c1 = p.variable("C1"); + VariableReferenceExpression c2 = p.variable("C2"); ValuesNode valuesA = p.values(a1); ValuesNode valuesB = p.values(b1, b2); ValuesNode valuesC = p.values(c1, c2); @@ -134,11 +135,7 @@ public void testRetainsOutputSymbols() valuesB, valuesC, ImmutableList.of(equiJoinClause(b1, c1)), - ImmutableList.of( - b1, - b2, - c1, - c2), + ImmutableList.of(b1, b2, c1, c2), Optional.empty()), ImmutableList.of(equiJoinClause(a1, b1)), ImmutableList.of(a1, b1), @@ -146,7 +143,7 @@ public void testRetainsOutputSymbols() MultiJoinNode expected = MultiJoinNode.builder() .setSources(valuesA, valuesB, valuesC) .setFilter(and(createEqualsExpression(b1, c1), createEqualsExpression(a1, b1))) - .setOutputSymbols(a1, b1) + .setOutputVariables(a1, b1) .build(); assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT), expected); } @@ -155,22 +152,22 @@ public void testRetainsOutputSymbols() public void testCombinesCriteriaAndFilters() { PlanBuilder p = planBuilder(); - Symbol a1 = p.symbol("A1"); - Symbol b1 = p.symbol("B1"); - Symbol b2 = p.symbol("B2"); - Symbol c1 = p.symbol("C1"); - Symbol c2 = p.symbol("C2"); + VariableReferenceExpression a1 = p.variable("A1"); + VariableReferenceExpression b1 = p.variable("B1"); + VariableReferenceExpression b2 = p.variable("B2"); + VariableReferenceExpression c1 = p.variable("C1"); + VariableReferenceExpression c2 = p.variable("C2"); ValuesNode valuesA = p.values(a1); ValuesNode valuesB = p.values(b1, b2); ValuesNode valuesC = p.values(c1, c2); Expression bcFilter = and( - new ComparisonExpression(GREATER_THAN, c2.toSymbolReference(), new LongLiteral("0")), - new ComparisonExpression(NOT_EQUAL, c2.toSymbolReference(), new LongLiteral("7")), - new ComparisonExpression(GREATER_THAN, b2.toSymbolReference(), c2.toSymbolReference())); + new ComparisonExpression(GREATER_THAN, new SymbolReference(c2.getName()), new LongLiteral("0")), + new ComparisonExpression(NOT_EQUAL, new SymbolReference(c2.getName()), new LongLiteral("7")), + new ComparisonExpression(GREATER_THAN, new SymbolReference(b2.getName()), new SymbolReference(c2.getName()))); ComparisonExpression abcFilter = new ComparisonExpression( LESS_THAN, - new ArithmeticBinaryExpression(ADD, a1.toSymbolReference(), c1.toSymbolReference()), - b1.toSymbolReference()); + new ArithmeticBinaryExpression(ADD, new SymbolReference(a1.getName()), new SymbolReference(c1.getName())), + new SymbolReference(b1.getName())); JoinNode joinNode = p.join( INNER, valuesA, @@ -179,18 +176,14 @@ public void testCombinesCriteriaAndFilters() valuesB, valuesC, ImmutableList.of(equiJoinClause(b1, c1)), - ImmutableList.of( - b1, - b2, - c1, - c2), + ImmutableList.of(b1, b2, c1, c2), Optional.of(bcFilter)), ImmutableList.of(equiJoinClause(a1, b1)), ImmutableList.of(a1, b1, b2, c1, c2), Optional.of(abcFilter)); MultiJoinNode expected = new MultiJoinNode( new LinkedHashSet<>(ImmutableList.of(valuesA, valuesB, valuesC)), - and(new ComparisonExpression(EQUAL, b1.toSymbolReference(), c1.toSymbolReference()), new ComparisonExpression(EQUAL, a1.toSymbolReference(), b1.toSymbolReference()), bcFilter, abcFilter), + and(new ComparisonExpression(EQUAL, new SymbolReference(b1.getName()), new SymbolReference(c1.getName())), new ComparisonExpression(EQUAL, new SymbolReference(a1.getName()), new SymbolReference(b1.getName())), bcFilter, abcFilter), ImmutableList.of(a1, b1, b2, c1, c2)); assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT), expected); } @@ -199,13 +192,13 @@ public void testCombinesCriteriaAndFilters() public void testConvertsBushyTrees() { PlanBuilder p = planBuilder(); - Symbol a1 = p.symbol("A1"); - Symbol b1 = p.symbol("B1"); - Symbol c1 = p.symbol("C1"); - Symbol d1 = p.symbol("D1"); - Symbol d2 = p.symbol("D2"); - Symbol e1 = p.symbol("E1"); - Symbol e2 = p.symbol("E2"); + VariableReferenceExpression a1 = p.variable("A1"); + VariableReferenceExpression b1 = p.variable("B1"); + VariableReferenceExpression c1 = p.variable("C1"); + VariableReferenceExpression d1 = p.variable("D1"); + VariableReferenceExpression d2 = p.variable("D2"); + VariableReferenceExpression e1 = p.variable("E1"); + VariableReferenceExpression e2 = p.variable("E2"); ValuesNode valuesA = p.values(a1); ValuesNode valuesB = p.values(b1); ValuesNode valuesC = p.values(c1); @@ -233,26 +226,15 @@ public void testConvertsBushyTrees() ImmutableList.of( equiJoinClause(d1, e1), equiJoinClause(d2, e2)), - ImmutableList.of( - d1, - d2, - e1, - e2), + ImmutableList.of(d1, d2, e1, e2), Optional.empty()), ImmutableList.of(equiJoinClause(b1, e1)), - ImmutableList.of( - a1, - b1, - c1, - d1, - d2, - e1, - e2), + ImmutableList.of(a1, b1, c1, d1, d2, e1, e2), Optional.empty()); MultiJoinNode expected = MultiJoinNode.builder() .setSources(valuesA, valuesB, valuesC, valuesD, valuesE) .setFilter(and(createEqualsExpression(a1, b1), createEqualsExpression(a1, c1), createEqualsExpression(d1, e1), createEqualsExpression(d2, e2), createEqualsExpression(b1, e1))) - .setOutputSymbols(a1, b1, c1, d1, d2, e1, e2) + .setOutputVariables(a1, b1, c1, d1, d2, e1, e2) .build(); assertEquals(toMultiJoinNode(joinNode, noLookup(), 5), expected); } @@ -261,13 +243,13 @@ public void testConvertsBushyTrees() public void testMoreThanJoinLimit() { PlanBuilder p = planBuilder(); - Symbol a1 = p.symbol("A1"); - Symbol b1 = p.symbol("B1"); - Symbol c1 = p.symbol("C1"); - Symbol d1 = p.symbol("D1"); - Symbol d2 = p.symbol("D2"); - Symbol e1 = p.symbol("E1"); - Symbol e2 = p.symbol("E2"); + VariableReferenceExpression a1 = p.variable("A1"); + VariableReferenceExpression b1 = p.variable("B1"); + VariableReferenceExpression c1 = p.variable("C1"); + VariableReferenceExpression d1 = p.variable("D1"); + VariableReferenceExpression d2 = p.variable("D2"); + VariableReferenceExpression e1 = p.variable("E1"); + VariableReferenceExpression e2 = p.variable("E2"); ValuesNode valuesA = p.values(a1); ValuesNode valuesB = p.values(b1); ValuesNode valuesC = p.values(c1); @@ -287,11 +269,7 @@ public void testMoreThanJoinLimit() ImmutableList.of( equiJoinClause(d1, e1), equiJoinClause(d2, e2)), - ImmutableList.of( - d1, - d2, - e1, - e2), + ImmutableList.of(d1, d2, e1, e2), Optional.empty()); JoinNode joinNode = p.join( INNER, @@ -304,31 +282,24 @@ public void testMoreThanJoinLimit() Optional.empty()), join2, ImmutableList.of(equiJoinClause(b1, e1)), - ImmutableList.of( - a1, - b1, - c1, - d1, - d2, - e1, - e2), + ImmutableList.of(a1, b1, c1, d1, d2, e1, e2), Optional.empty()); MultiJoinNode expected = MultiJoinNode.builder() .setSources(join1, join2, valuesC) .setFilter(and(createEqualsExpression(a1, c1), createEqualsExpression(b1, e1))) - .setOutputSymbols(a1, b1, c1, d1, d2, e1, e2) + .setOutputVariables(a1, b1, c1, d1, d2, e1, e2) .build(); assertEquals(toMultiJoinNode(joinNode, noLookup(), 2), expected); } - private ComparisonExpression createEqualsExpression(Symbol left, Symbol right) + private ComparisonExpression createEqualsExpression(VariableReferenceExpression left, VariableReferenceExpression right) { - return new ComparisonExpression(EQUAL, left.toSymbolReference(), right.toSymbolReference()); + return new ComparisonExpression(EQUAL, new SymbolReference(left.getName()), new SymbolReference(right.getName())); } - private EquiJoinClause equiJoinClause(Symbol symbol1, Symbol symbol2) + private EquiJoinClause equiJoinClause(VariableReferenceExpression variable1, VariableReferenceExpression variable2) { - return new EquiJoinClause(symbol1, symbol2); + return new EquiJoinClause(variable1, variable2); } private PlanBuilder planBuilder() diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLambdaCaptureDesugaringRewriter.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLambdaCaptureDesugaringRewriter.java index acc192a98e572..88917b0f9f617 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLambdaCaptureDesugaringRewriter.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLambdaCaptureDesugaringRewriter.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; @@ -28,6 +27,7 @@ import java.util.Map; import java.util.stream.Stream; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.iterative.rule.LambdaCaptureDesugaringRewriter.rewrite; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static java.util.stream.Collectors.toList; @@ -38,10 +38,10 @@ public class TestLambdaCaptureDesugaringRewriter @Test public void testRewriteBasicLambda() { - final Map symbols = ImmutableMap.of(new Symbol("a"), BigintType.BIGINT); + final Map symbols = ImmutableMap.of(new Symbol("a"), BIGINT, new Symbol("x"), BIGINT); final SymbolAllocator allocator = new SymbolAllocator(symbols); - assertEquals(rewrite(expression("x -> a + x"), allocator.getTypes(), allocator), + assertEquals(rewrite(expression("x -> a + x"), allocator), new BindExpression( ImmutableList.of(expression("a")), new LambdaExpression( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java index ae64b236e1159..72b0e87ee33e8 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java @@ -67,7 +67,7 @@ public class TestMergeAdjacentWindows public void testPlanWithoutWindowNode() { tester().assertThat(new GatherAndMergeWindows.MergeAdjacentWindowsOverProjects(0)) - .on(p -> p.values(p.symbol("a"))) + .on(p -> p.values(p.variable("a"))) .doesNotFire(); } @@ -78,8 +78,8 @@ public void testPlanWithSingleWindowNode() .on(p -> p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), - p.values(p.symbol("a")))) + ImmutableMap.of(p.variable(p.symbol("avg_1")), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), + p.values(p.variable("a")))) .doesNotFire(); } @@ -90,11 +90,11 @@ public void testDistinctAdjacentWindowSpecifications() .on(p -> p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), + ImmutableMap.of(p.variable(p.symbol("avg_1")), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), p.window( newWindowNodeSpecification(p, "b"), - ImmutableMap.of(p.symbol("sum_1"), newWindowNodeFunction("sum", SUM_FUNCTION_HANDLE, "b")), - p.values(p.symbol("b"))))) + ImmutableMap.of(p.variable(p.symbol("sum_1")), newWindowNodeFunction("sum", SUM_FUNCTION_HANDLE, "b")), + p.values(p.variable("b"))))) .doesNotFire(); } @@ -105,13 +105,13 @@ public void testIntermediateNonProjectNode() .on(p -> p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avg_2"), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), + ImmutableMap.of(p.variable(p.symbol("avg_2")), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), p.filter( expression("a > 5"), p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), - p.values(p.symbol("a")))))) + ImmutableMap.of(p.variable(p.symbol("avg_1")), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), + p.values(p.variable("a")))))) .doesNotFire(); } @@ -122,11 +122,11 @@ public void testDependentAdjacentWindowsIdenticalSpecifications() .on(p -> p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "avg_2")), + ImmutableMap.of(p.variable(p.symbol("avg_1")), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "avg_2")), p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avg_2"), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), - p.values(p.symbol("a"))))) + ImmutableMap.of(p.variable(p.symbol("avg_2")), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), + p.values(p.variable("a"))))) .doesNotFire(); } @@ -137,11 +137,11 @@ public void testDependentAdjacentWindowsDistinctSpecifications() .on(p -> p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "avg_2")), + ImmutableMap.of(p.variable(p.symbol("avg_1")), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "avg_2")), p.window( newWindowNodeSpecification(p, "b"), - ImmutableMap.of(p.symbol("avg_2"), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), - p.values(p.symbol("a"), p.symbol("b"))))) + ImmutableMap.of(p.variable(p.symbol("avg_2")), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), + p.values(p.variable("a"), p.variable("b"))))) .doesNotFire(); } @@ -152,11 +152,11 @@ public void testIdenticalAdjacentWindowSpecifications() .on(p -> p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), + ImmutableMap.of(p.variable(p.symbol("avg_1")), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("sum_1"), newWindowNodeFunction("sum", SUM_FUNCTION_HANDLE, "a")), - p.values(p.symbol("a"))))) + ImmutableMap.of(p.variable(p.symbol("sum_1")), newWindowNodeFunction("sum", SUM_FUNCTION_HANDLE, "a")), + p.values(p.variable("a"))))) .matches( window(windowMatcherBuilder -> windowMatcherBuilder .specification(specificationA) @@ -177,18 +177,18 @@ public void testIntermediateProjectNodes() .on(p -> p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("lagOutput"), newWindowNodeFunction("lag", LAG_FUNCTION_HANDLE, "a", "one")), + ImmutableMap.of(p.variable(p.symbol("lagOutput")), newWindowNodeFunction("lag", LAG_FUNCTION_HANDLE, "a", "one")), p.project( Assignments.builder() - .put(p.symbol("one"), expression("CAST(1 AS bigint)")) - .putIdentities(ImmutableList.of(p.symbol("a"), p.symbol("avgOutput"))) + .put(p.variable("one"), expression("CAST(1 AS bigint)")) + .putIdentities(ImmutableList.of(p.variable("a"), p.variable("avgOutput"))) .build(), p.project( - Assignments.identity(p.symbol("a"), p.symbol("avgOutput"), p.symbol("unused")), + Assignments.identity(p.variable("a"), p.variable("avgOutput"), p.variable("unused")), p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avgOutput"), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), - p.values(p.symbol("a"), p.symbol("unused"))))))) + ImmutableMap.of(p.variable(p.symbol("avgOutput")), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), + p.values(p.variable("a"), p.variable("unused"))))))) .matches( strictProject( ImmutableMap.of( @@ -214,7 +214,7 @@ public void testIntermediateProjectNodes() private static WindowNode.Specification newWindowNodeSpecification(PlanBuilder planBuilder, String symbolName) { - return new WindowNode.Specification(ImmutableList.of(planBuilder.symbol(symbolName, BIGINT)), Optional.empty()); + return new WindowNode.Specification(ImmutableList.of(planBuilder.variable(planBuilder.symbol(symbolName, BIGINT))), Optional.empty()); } private WindowNode.Function newWindowNodeFunction(String name, FunctionHandle functionHandle, String... symbols) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeFilters.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeFilters.java index 9ce71f2c884fa..fa60054df0f2a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeFilters.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeFilters.java @@ -31,7 +31,7 @@ public void test() .on(p -> p.filter(expression("b > 44"), p.filter(expression("a < 42"), - p.values(p.symbol("a"), p.symbol("b"))))) + p.values(p.variable("a"), p.variable("b"))))) .matches(filter("(a < 42) AND (b > 44)", values(ImmutableMap.of("a", 0, "b", 1)))); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java index 9741fe63546e6..386688f14e6ab 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java @@ -29,12 +29,12 @@ public void testNoDistinct() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), expression("count(input1)"), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), expression("count(input2)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output1")), expression("count(input1)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output2")), expression("count(input2)"), ImmutableList.of(BIGINT)) .source( p.values( - p.symbol("input1"), - p.symbol("input2"))))) + p.variable("input1"), + p.variable("input2"))))) .doesNotFire(); } @@ -44,11 +44,11 @@ public void testSingleDistinct() tester().assertThat(new MultipleDistinctAggregationToMarkDistinct()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), expression("count(DISTINCT input1)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output1")), expression("count(DISTINCT input1)"), ImmutableList.of(BIGINT)) .source( p.values( - p.symbol("input1"), - p.symbol("input2"))))) + p.variable("input1"), + p.variable("input2"))))) .doesNotFire(); } @@ -58,10 +58,10 @@ public void testMultipleAggregations() tester().assertThat(new MultipleDistinctAggregationToMarkDistinct()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), expression("count(DISTINCT input)"), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), expression("sum(DISTINCT input)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output1")), expression("count(DISTINCT input)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output2")), expression("sum(DISTINCT input)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("input"))))) + p.values(p.variable("input"))))) .doesNotFire(); } @@ -71,23 +71,23 @@ public void testDistinctWithFilter() tester().assertThat(new MultipleDistinctAggregationToMarkDistinct()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), expression("count(DISTINCT input1) filter (where input2 > 0)"), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), expression("count(DISTINCT input2) filter (where input1 > 0)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output1")), expression("count(DISTINCT input1) filter (where input2 > 0)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output2")), expression("count(DISTINCT input2) filter (where input1 > 0)"), ImmutableList.of(BIGINT)) .source( p.values( - p.symbol("input1"), - p.symbol("input2"))))) + p.variable("input1"), + p.variable("input2"))))) .doesNotFire(); tester().assertThat(new MultipleDistinctAggregationToMarkDistinct()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), expression("count(DISTINCT input1) filter (where input2 > 0)"), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), expression("count(DISTINCT input2)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output1")), expression("count(DISTINCT input1) filter (where input2 > 0)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output2")), expression("count(DISTINCT input2)"), ImmutableList.of(BIGINT)) .source( p.values( - p.symbol("input1"), - p.symbol("input2"))))) + p.variable("input1"), + p.variable("input2"))))) .doesNotFire(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPickTableLayout.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPickTableLayout.java index 82cfd24be6e70..633f81d54fb1c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPickTableLayout.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPickTableLayout.java @@ -78,7 +78,7 @@ public void doesNotFireIfNoTableScan() { for (Rule rule : pickTableLayout.rules()) { tester().assertThat(rule) - .on(p -> p.values(p.symbol("a", BIGINT))) + .on(p -> p.values(p.variable("a", BIGINT))) .doesNotFire(); } } @@ -90,7 +90,8 @@ public void doesNotFireIfTableScanHasTableLayout() .on(p -> p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)))) + ImmutableList.of(p.variable(p.symbol("nationkey", BIGINT))), + ImmutableMap.of(p.variable(p.symbol("nationkey", BIGINT)), new TpchColumnHandle("nationkey", BIGINT)))) .doesNotFire(); } @@ -102,7 +103,8 @@ public void eliminateTableScanWhenNoLayoutExist() p.tableScan( ordersTableHandle, ImmutableList.of(p.symbol("orderstatus", createVarcharType(1))), - ImmutableMap.of(p.symbol("orderstatus", createVarcharType(1)), new TpchColumnHandle("orderstatus", createVarcharType(1)))))) + ImmutableList.of(p.variable(p.symbol("orderstatus", createVarcharType(1)))), + ImmutableMap.of(p.variable(p.symbol("orderstatus", createVarcharType(1))), new TpchColumnHandle("orderstatus", createVarcharType(1)))))) .matches(values("A")); } @@ -115,7 +117,8 @@ public void replaceWithExistsWhenNoLayoutExist() p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), columnHandle), + ImmutableList.of(p.variable(p.symbol("nationkey", BIGINT))), + ImmutableMap.of(p.variable(p.symbol("nationkey", BIGINT)), columnHandle), TupleDomain.none(), TupleDomain.none()))) .matches(values("A")); @@ -129,7 +132,8 @@ public void doesNotFireIfRuleNotChangePlan() p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)), + ImmutableList.of(p.variable(p.symbol("nationkey", BIGINT))), + ImmutableMap.of(p.variable(p.symbol("nationkey", BIGINT)), new TpchColumnHandle("nationkey", BIGINT)), TupleDomain.all(), TupleDomain.all()))) .doesNotFire(); @@ -146,7 +150,8 @@ public void ruleAddedTableLayoutToTableScan() TestingTransactionHandle.create(), Optional.empty()), ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)))) + ImmutableList.of(p.variable(p.symbol("nationkey", BIGINT))), + ImmutableMap.of(p.variable(p.symbol("nationkey", BIGINT)), new TpchColumnHandle("nationkey", BIGINT)))) .matches( constrainedTableScanWithTableLayout("nation", ImmutableMap.of(), ImmutableMap.of("nationkey", "nationkey"))); } @@ -162,7 +167,8 @@ public void ruleAddedTableLayoutToFilterTableScan() p.tableScan( ordersTableHandle, ImmutableList.of(p.symbol("orderstatus", createVarcharType(1))), - ImmutableMap.of(p.symbol("orderstatus", createVarcharType(1)), new TpchColumnHandle("orderstatus", createVarcharType(1)))))) + ImmutableList.of(p.variable(p.symbol("orderstatus", createVarcharType(1)))), + ImmutableMap.of(p.variable(p.symbol("orderstatus", createVarcharType(1))), new TpchColumnHandle("orderstatus", createVarcharType(1)))))) .matches( constrainedTableScanWithTableLayout("orders", filterConstraint, ImmutableMap.of("orderstatus", "orderstatus"))); } @@ -175,7 +181,8 @@ public void ruleAddedNewTableLayoutIfTableScanHasEmptyConstraint() p.tableScan( ordersTableHandle, ImmutableList.of(p.symbol("orderstatus", createVarcharType(1))), - ImmutableMap.of(p.symbol("orderstatus", createVarcharType(1)), new TpchColumnHandle("orderstatus", createVarcharType(1)))))) + ImmutableList.of(p.variable(p.symbol("orderstatus", createVarcharType(1)))), + ImmutableMap.of(p.variable(p.symbol("orderstatus", createVarcharType(1))), new TpchColumnHandle("orderstatus", createVarcharType(1)))))) .matches( constrainedTableScanWithTableLayout( "orders", @@ -192,7 +199,8 @@ public void ruleWithPushdownableToTableLayoutPredicate() p.tableScan( ordersTableHandle, ImmutableList.of(p.symbol("orderstatus", orderStatusType)), - ImmutableMap.of(p.symbol("orderstatus", orderStatusType), new TpchColumnHandle("orderstatus", orderStatusType))))) + ImmutableList.of(p.variable(p.symbol("orderstatus", orderStatusType))), + ImmutableMap.of(p.variable(p.symbol("orderstatus", orderStatusType)), new TpchColumnHandle("orderstatus", orderStatusType))))) .matches(constrainedTableScanWithTableLayout( "orders", ImmutableMap.of("orderstatus", singleValue(orderStatusType, utf8Slice("O"))), @@ -208,7 +216,8 @@ public void nonDeterministicPredicate() p.tableScan( ordersTableHandle, ImmutableList.of(p.symbol("orderstatus", orderStatusType)), - ImmutableMap.of(p.symbol("orderstatus", orderStatusType), new TpchColumnHandle("orderstatus", orderStatusType))))) + ImmutableList.of(p.variable(p.symbol("orderstatus", orderStatusType))), + ImmutableMap.of(p.variable(p.symbol("orderstatus", orderStatusType)), new TpchColumnHandle("orderstatus", orderStatusType))))) .matches( filter("rand() = 0", constrainedTableScanWithTableLayout( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationColumns.java index 97f99717e4a52..cddb388e5447f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationColumns.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.Assignments; @@ -42,7 +42,7 @@ public class TestPruneAggregationColumns public void testNotAllInputsReferenced() { tester().assertThat(new PruneAggregationColumns()) - .on(p -> buildProjectedAggregation(p, symbol -> symbol.getName().equals("b"))) + .on(p -> buildProjectedAggregation(p, variable -> variable.getName().equals("b"))) .matches( strictProject( ImmutableMap.of("b", expression("b")), @@ -65,11 +65,11 @@ public void testAllOutputsReferenced() .doesNotFire(); } - private ProjectNode buildProjectedAggregation(PlanBuilder planBuilder, Predicate projectionFilter) + private ProjectNode buildProjectedAggregation(PlanBuilder planBuilder, Predicate projectionFilter) { - Symbol a = planBuilder.symbol("a"); - Symbol b = planBuilder.symbol("b"); - Symbol key = planBuilder.symbol("key"); + VariableReferenceExpression a = planBuilder.variable("a"); + VariableReferenceExpression b = planBuilder.variable("b"); + VariableReferenceExpression key = planBuilder.variable("key"); return planBuilder.project( Assignments.identity(ImmutableList.of(a, b).stream().filter(projectionFilter).collect(toImmutableSet())), planBuilder.aggregation(aggregationBuilder -> aggregationBuilder diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationSourceColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationSourceColumns.java index b627b0ac6c11f..3307687fb5342 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationSourceColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationSourceColumns.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; @@ -72,22 +73,26 @@ public void testAllInputsReferenced() private AggregationNode buildAggregation(PlanBuilder planBuilder, Predicate sourceSymbolFilter) { - Symbol avg = planBuilder.symbol("avg"); + VariableReferenceExpression avg = planBuilder.variable(planBuilder.symbol("avg")); Symbol input = planBuilder.symbol("input"); Symbol key = planBuilder.symbol("key"); Symbol keyHash = planBuilder.symbol("keyHash"); Symbol mask = planBuilder.symbol("mask"); Symbol unused = planBuilder.symbol("unused"); - List sourceSymbols = ImmutableList.of(input, key, keyHash, mask, unused); + List filteredSourceSymboles = ImmutableList.of(input, key, keyHash, mask, unused).stream() + .filter(sourceSymbolFilter) + .collect(toImmutableList()); + List filteredSourceVariables = filteredSourceSymboles.stream() + .map(planBuilder::variable) + .collect(toImmutableList()); + return planBuilder.aggregation(aggregationBuilder -> aggregationBuilder - .singleGroupingSet(key) - .addAggregation(avg, planBuilder.expression("avg(input)"), ImmutableList.of(BIGINT), mask) - .hashSymbol(keyHash) + .singleGroupingSet(planBuilder.variable(key)) + .addAggregation(avg, planBuilder.expression("avg(input)"), ImmutableList.of(BIGINT), planBuilder.variable(mask)) + .hashVariable(planBuilder.variable(keyHash)) .source( planBuilder.values( - sourceSymbols.stream() - .filter(sourceSymbolFilter) - .collect(toImmutableList()), + filteredSourceVariables, ImmutableList.of()))); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java index 83564fed1a1fa..b9354148e9ae1 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java @@ -15,6 +15,7 @@ import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; @@ -49,11 +50,11 @@ public void testDoesNotFireOnNonNestedAggregate() p.aggregation((a) -> a .globalGrouping() .addAggregation( - p.symbol("count_1", BIGINT), + p.variable(p.symbol("count_1", BIGINT)), new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), ImmutableList.of(BIGINT)) .source( - p.tableScan(ImmutableList.of(), ImmutableMap.of()))) + p.tableScan(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of()))) ).doesNotFire(); } @@ -64,13 +65,13 @@ public void testFiresOnNestedCountAggregate() .on(p -> p.aggregation((a) -> a .addAggregation( - p.symbol("count_1", BIGINT), + p.variable(p.symbol("count_1", BIGINT)), new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), ImmutableList.of(BIGINT)) .globalGrouping() .step(AggregationNode.Step.SINGLE) .source( p.aggregation((aggregationBuilder) -> aggregationBuilder - .source(p.tableScan(ImmutableList.of(), ImmutableMap.of())) + .source(p.tableScan(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) .globalGrouping() .step(AggregationNode.Step.SINGLE))))) .matches(values(ImmutableMap.of("count_1", 0))); @@ -83,13 +84,13 @@ public void testFiresOnCountAggregateOverValues() .on(p -> p.aggregation((a) -> a .addAggregation( - p.symbol("count_1", BIGINT), + p.variable(p.symbol("count_1", BIGINT)), new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), ImmutableList.of(BIGINT)) .step(AggregationNode.Step.SINGLE) .globalGrouping() .source(p.values( - ImmutableList.of(p.symbol("orderkey")), + ImmutableList.of(p.variable(p.symbol("orderkey"))), ImmutableList.of(PlanBuilder.constantExpressions(BIGINT, 1)))))) .matches(values(ImmutableMap.of("count_1", 0))); } @@ -101,12 +102,12 @@ public void testFiresOnCountAggregateOverEnforceSingleRow() .on(p -> p.aggregation((a) -> a .addAggregation( - p.symbol("count_1", BIGINT), + p.variable(p.symbol("count_1", BIGINT)), new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), ImmutableList.of(BIGINT)) .step(AggregationNode.Step.SINGLE) .globalGrouping() - .source(p.enforceSingleRow(p.tableScan(ImmutableList.of(), ImmutableMap.of()))))) + .source(p.enforceSingleRow(p.tableScan(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of()))))) .matches(values(ImmutableMap.of("count_1", 0))); } @@ -117,7 +118,7 @@ public void testDoesNotFireOnNestedCountAggregateWithNonEmptyGroupBy() .on(p -> p.aggregation((a) -> a .addAggregation( - p.symbol("count_1", BIGINT), + p.variable(p.symbol("count_1", BIGINT)), new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), ImmutableList.of(BIGINT)) .step(AggregationNode.Step.SINGLE) @@ -125,9 +126,9 @@ public void testDoesNotFireOnNestedCountAggregateWithNonEmptyGroupBy() .source( p.aggregation(aggregationBuilder -> { aggregationBuilder - .source(p.tableScan(ImmutableList.of(), ImmutableMap.of())).groupingSets(singleGroupingSet(ImmutableList.of(p.symbol("orderkey")))); + .source(p.tableScan(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())).groupingSets(singleGroupingSet(ImmutableList.of(p.variable("orderkey")))); aggregationBuilder - .source(p.tableScan(ImmutableList.of(), ImmutableMap.of())); + .source(p.tableScan(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())); })))) .doesNotFire(); } @@ -138,14 +139,15 @@ public void testDoesNotFireOnNestedNonCountAggregate() tester().assertThat(new PruneCountAggregationOverScalar(getFunctionManager())) .on(p -> { Symbol totalPrice = p.symbol("total_price", DOUBLE); + VariableReferenceExpression totalPriceVariable = new VariableReferenceExpression(totalPrice.getName(), DOUBLE); AggregationNode inner = p.aggregation((a) -> a - .addAggregation(totalPrice, + .addAggregation(totalPriceVariable, new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(new SymbolReference("totalprice"))), ImmutableList.of(DOUBLE)) .globalGrouping() .source( p.project( - Assignments.of(totalPrice, totalPrice.toSymbolReference()), + Assignments.of(totalPriceVariable, totalPrice.toSymbolReference()), p.tableScan( new TableHandle( new ConnectorId("local"), @@ -153,11 +155,12 @@ public void testDoesNotFireOnNestedNonCountAggregate() TestingTransactionHandle.create(), Optional.empty()), ImmutableList.of(totalPrice), - ImmutableMap.of(totalPrice, new TpchColumnHandle(totalPrice.getName(), DOUBLE)))))); + ImmutableList.of(totalPriceVariable), + ImmutableMap.of(totalPriceVariable, new TpchColumnHandle(totalPrice.getName(), DOUBLE)))))); return p.aggregation((a) -> a .addAggregation( - p.symbol("sum_outer", DOUBLE), + p.variable(p.symbol("sum_outer", DOUBLE)), new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(new SymbolReference("sum_inner"))), ImmutableList.of(DOUBLE)) .globalGrouping() diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java index 3a68548a4398f..1bb90bd3115f1 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; @@ -41,7 +41,7 @@ public class TestPruneCrossJoinColumns public void testLeftInputNotReferenced() { tester().assertThat(new PruneCrossJoinColumns()) - .on(p -> buildProjectedCrossJoin(p, symbol -> symbol.getName().equals("rightValue"))) + .on(p -> buildProjectedCrossJoin(p, variable -> variable.getName().equals("rightValue"))) .matches( strictProject( ImmutableMap.of("rightValue", PlanMatchPattern.expression("rightValue")), @@ -60,7 +60,7 @@ public void testLeftInputNotReferenced() public void testRightInputNotReferenced() { tester().assertThat(new PruneCrossJoinColumns()) - .on(p -> buildProjectedCrossJoin(p, symbol -> symbol.getName().equals("leftValue"))) + .on(p -> buildProjectedCrossJoin(p, variable -> variable.getName().equals("leftValue"))) .matches( strictProject( ImmutableMap.of("leftValue", PlanMatchPattern.expression("leftValue")), @@ -83,11 +83,11 @@ public void testAllInputsReferenced() .doesNotFire(); } - private static PlanNode buildProjectedCrossJoin(PlanBuilder p, Predicate projectionFilter) + private static PlanNode buildProjectedCrossJoin(PlanBuilder p, Predicate projectionFilter) { - Symbol leftValue = p.symbol("leftValue"); - Symbol rightValue = p.symbol("rightValue"); - List outputs = ImmutableList.of(leftValue, rightValue); + VariableReferenceExpression leftValue = p.variable("leftValue"); + VariableReferenceExpression rightValue = p.variable("rightValue"); + List outputs = ImmutableList.of(leftValue, rightValue); return p.project( Assignments.identity( outputs.stream() diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneFilterColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneFilterColumns.java index 3c3df254f4370..d59d696198f69 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneFilterColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneFilterColumns.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.Assignments; @@ -38,7 +38,7 @@ public class TestPruneFilterColumns public void testNotAllInputsReferenced() { tester().assertThat(new PruneFilterColumns()) - .on(p -> buildProjectedFilter(p, symbol -> symbol.getName().equals("b"))) + .on(p -> buildProjectedFilter(p, variable -> variable.getName().equals("b"))) .matches( strictProject( ImmutableMap.of("b", expression("b")), @@ -65,10 +65,10 @@ public void testAllOutputsReferenced() .doesNotFire(); } - private ProjectNode buildProjectedFilter(PlanBuilder planBuilder, Predicate projectionFilter) + private ProjectNode buildProjectedFilter(PlanBuilder planBuilder, Predicate projectionFilter) { - Symbol a = planBuilder.symbol("a"); - Symbol b = planBuilder.symbol("b"); + VariableReferenceExpression a = planBuilder.variable("a"); + VariableReferenceExpression b = planBuilder.variable("b"); return planBuilder.project( Assignments.identity(Stream.of(a, b).filter(projectionFilter).collect(toImmutableSet())), planBuilder.filter( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java index a307145f798a2..205c2999bc69b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java @@ -18,7 +18,7 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.predicate.Domain; import com.facebook.presto.spi.predicate.TupleDomain; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.Assignments; @@ -51,7 +51,7 @@ public class TestPruneIndexSourceColumns public void testNotAllOutputsReferenced() { tester().assertThat(new PruneIndexSourceColumns()) - .on(p -> buildProjectedIndexSource(p, symbol -> symbol.getName().equals("orderkey"))) + .on(p -> buildProjectedIndexSource(p, variable -> variable.getName().equals("orderkey"))) .matches( strictProject( ImmutableMap.of("x", expression("orderkey")), @@ -71,11 +71,11 @@ public void testAllOutputsReferenced() .doesNotFire(); } - private static PlanNode buildProjectedIndexSource(PlanBuilder p, Predicate projectionFilter) + private static PlanNode buildProjectedIndexSource(PlanBuilder p, Predicate projectionFilter) { - Symbol orderkey = p.symbol("orderkey", INTEGER); - Symbol custkey = p.symbol("custkey", INTEGER); - Symbol totalprice = p.symbol("totalprice", DOUBLE); + VariableReferenceExpression orderkey = p.variable("orderkey", INTEGER); + VariableReferenceExpression custkey = p.variable("custkey", INTEGER); + VariableReferenceExpression totalprice = p.variable("totalprice", DOUBLE); ColumnHandle orderkeyHandle = new TpchColumnHandle(orderkey.getName(), INTEGER); ColumnHandle custkeyHandle = new TpchColumnHandle(custkey.getName(), INTEGER); ColumnHandle totalpriceHandle = new TpchColumnHandle(totalprice.getName(), DOUBLE); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinChildrenColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinChildrenColumns.java index 935ee820db41a..130c0be218e6a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinChildrenColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinChildrenColumns.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; @@ -42,7 +42,7 @@ public class TestPruneJoinChildrenColumns public void testNotAllInputsRereferenced() { tester().assertThat(new PruneJoinChildrenColumns()) - .on(p -> buildJoin(p, symbol -> symbol.getName().equals("leftValue"))) + .on(p -> buildJoin(p, variable -> variable.getName().equals("leftValue"))) .matches( join( JoinNode.Type.INNER, @@ -69,8 +69,8 @@ public void testCrossJoinDoesNotFire() { tester().assertThat(new PruneJoinColumns()) .on(p -> { - Symbol leftValue = p.symbol("leftValue"); - Symbol rightValue = p.symbol("rightValue"); + VariableReferenceExpression leftValue = p.variable("leftValue"); + VariableReferenceExpression rightValue = p.variable("rightValue"); return p.join( JoinNode.Type.INNER, p.values(leftValue), @@ -84,15 +84,15 @@ public void testCrossJoinDoesNotFire() .doesNotFire(); } - private static PlanNode buildJoin(PlanBuilder p, Predicate joinOutputFilter) + private static PlanNode buildJoin(PlanBuilder p, Predicate joinOutputFilter) { - Symbol leftKey = p.symbol("leftKey"); - Symbol leftKeyHash = p.symbol("leftKeyHash"); - Symbol leftValue = p.symbol("leftValue"); - Symbol rightKey = p.symbol("rightKey"); - Symbol rightKeyHash = p.symbol("rightKeyHash"); - Symbol rightValue = p.symbol("rightValue"); - List outputs = ImmutableList.of(leftValue, rightValue); + VariableReferenceExpression leftKey = p.variable("leftKey"); + VariableReferenceExpression leftKeyHash = p.variable("leftKeyHash"); + VariableReferenceExpression leftValue = p.variable("leftValue"); + VariableReferenceExpression rightKey = p.variable("rightKey"); + VariableReferenceExpression rightKeyHash = p.variable("rightKeyHash"); + VariableReferenceExpression rightValue = p.variable("rightValue"); + List outputs = ImmutableList.of(leftValue, rightValue); return p.join( JoinNode.Type.INNER, p.values(leftKey, leftKeyHash, leftValue), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinColumns.java index 7bfcc8337221d..d697f82180222 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinColumns.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; @@ -42,7 +42,7 @@ public class TestPruneJoinColumns public void testNotAllOutputsReferenced() { tester().assertThat(new PruneJoinColumns()) - .on(p -> buildProjectedJoin(p, symbol -> symbol.getName().equals("rightValue"))) + .on(p -> buildProjectedJoin(p, variable -> variable.getName().equals("rightValue"))) .matches( strictProject( ImmutableMap.of("rightValue", PlanMatchPattern.expression("rightValue")), @@ -68,8 +68,8 @@ public void testCrossJoinDoesNotFire() { tester().assertThat(new PruneJoinColumns()) .on(p -> { - Symbol leftValue = p.symbol("leftValue"); - Symbol rightValue = p.symbol("rightValue"); + VariableReferenceExpression leftValue = p.variable("leftValue"); + VariableReferenceExpression rightValue = p.variable("rightValue"); return p.project( Assignments.of(), p.join( @@ -85,13 +85,13 @@ public void testCrossJoinDoesNotFire() .doesNotFire(); } - private static PlanNode buildProjectedJoin(PlanBuilder p, Predicate projectionFilter) + private static PlanNode buildProjectedJoin(PlanBuilder p, Predicate projectionFilter) { - Symbol leftKey = p.symbol("leftKey"); - Symbol leftValue = p.symbol("leftValue"); - Symbol rightKey = p.symbol("rightKey"); - Symbol rightValue = p.symbol("rightValue"); - List outputs = ImmutableList.of(leftKey, leftValue, rightKey, rightValue); + VariableReferenceExpression leftKey = p.variable("leftKey"); + VariableReferenceExpression leftValue = p.variable("leftValue"); + VariableReferenceExpression rightKey = p.variable("rightKey"); + VariableReferenceExpression rightValue = p.variable("rightValue"); + List outputs = ImmutableList.of(leftKey, leftValue, rightKey, rightValue); return p.project( Assignments.identity( outputs.stream() diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneLimitColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneLimitColumns.java index a2c79c372da49..2f41ac6646332 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneLimitColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneLimitColumns.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.Assignments; @@ -38,7 +38,7 @@ public class TestPruneLimitColumns public void testNotAllInputsReferenced() { tester().assertThat(new PruneLimitColumns()) - .on(p -> buildProjectedLimit(p, symbol -> symbol.getName().equals("b"))) + .on(p -> buildProjectedLimit(p, variable -> variable.getName().equals("b"))) .matches( strictProject( ImmutableMap.of("b", expression("b")), @@ -57,10 +57,10 @@ public void testAllOutputsReferenced() .doesNotFire(); } - private ProjectNode buildProjectedLimit(PlanBuilder planBuilder, Predicate projectionFilter) + private ProjectNode buildProjectedLimit(PlanBuilder planBuilder, Predicate projectionFilter) { - Symbol a = planBuilder.symbol("a"); - Symbol b = planBuilder.symbol("b"); + VariableReferenceExpression a = planBuilder.variable("a"); + VariableReferenceExpression b = planBuilder.variable("b"); return planBuilder.project( Assignments.identity(Stream.of(a, b).filter(projectionFilter).collect(toImmutableSet())), planBuilder.limit(1, planBuilder.values(a, b))); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java index d474e0dbc3ce8..3840d0830f5ff 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java @@ -13,9 +13,10 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -34,12 +35,12 @@ public void testMarkerSymbolNotReferenced() tester().assertThat(new PruneMarkDistinctColumns()) .on(p -> { - Symbol key = p.symbol("key"); - Symbol key2 = p.symbol("key2"); - Symbol mark = p.symbol("mark"); - Symbol unused = p.symbol("unused"); + VariableReferenceExpression key = p.variable("key"); + VariableReferenceExpression key2 = p.variable("key2"); + VariableReferenceExpression mark = p.variable("mark"); + VariableReferenceExpression unused = p.variable("unused"); return p.project( - Assignments.of(key2, key.toSymbolReference()), + Assignments.of(key2, new SymbolReference(key.getName())), p.markDistinct(mark, ImmutableList.of(key), p.values(key, unused))); }) .matches( @@ -54,10 +55,10 @@ public void testSourceSymbolNotReferenced() tester().assertThat(new PruneMarkDistinctColumns()) .on(p -> { - Symbol key = p.symbol("key"); - Symbol mark = p.symbol("mark"); - Symbol hash = p.symbol("hash"); - Symbol unused = p.symbol("unused"); + VariableReferenceExpression key = p.variable("key"); + VariableReferenceExpression mark = p.variable("mark"); + VariableReferenceExpression hash = p.variable("hash"); + VariableReferenceExpression unused = p.variable("unused"); return p.project( Assignments.identity(mark), p.markDistinct( @@ -83,8 +84,8 @@ public void testKeySymbolNotReferenced() tester().assertThat(new PruneMarkDistinctColumns()) .on(p -> { - Symbol key = p.symbol("key"); - Symbol mark = p.symbol("mark"); + VariableReferenceExpression key = p.variable("key"); + VariableReferenceExpression mark = p.variable("mark"); return p.project( Assignments.identity(mark), p.markDistinct(mark, ImmutableList.of(key), p.values(key))); @@ -98,8 +99,8 @@ public void testAllOutputsReferenced() tester().assertThat(new PruneMarkDistinctColumns()) .on(p -> { - Symbol key = p.symbol("key"); - Symbol mark = p.symbol("mark"); + VariableReferenceExpression key = p.variable("key"); + VariableReferenceExpression mark = p.variable("mark"); return p.project( Assignments.identity(key, mark), p.markDistinct(mark, ImmutableList.of(key), p.values(key))); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java index 3234bb3bfeea1..c50d994f7a36d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java @@ -15,6 +15,7 @@ import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; @@ -62,18 +63,23 @@ public void testBasics() private AggregationNode buildAggregation(PlanBuilder planBuilder) { - Symbol avg = planBuilder.symbol("avg"); - Symbol arrayAgg = planBuilder.symbol("array_agg"); + VariableReferenceExpression avg = planBuilder.variable(planBuilder.symbol("avg")); + VariableReferenceExpression arrayAgg = planBuilder.variable(planBuilder.symbol("array_agg")); Symbol input = planBuilder.symbol("input"); Symbol key = planBuilder.symbol("key"); Symbol keyHash = planBuilder.symbol("keyHash"); Symbol mask = planBuilder.symbol("mask"); List sourceSymbols = ImmutableList.of(input, key, keyHash, mask); + List sourceVariables = ImmutableList.of( + planBuilder.variable(input), + planBuilder.variable(key), + planBuilder.variable(keyHash), + planBuilder.variable(mask)); return planBuilder.aggregation(aggregationBuilder -> aggregationBuilder - .singleGroupingSet(key) - .addAggregation(avg, planBuilder.expression("avg(input order by input)"), ImmutableList.of(BIGINT), mask) - .addAggregation(arrayAgg, planBuilder.expression("array_agg(input order by input)"), ImmutableList.of(BIGINT), mask) - .hashSymbol(keyHash) - .source(planBuilder.values(sourceSymbols, ImmutableList.of()))); + .singleGroupingSet(planBuilder.variable(key)) + .addAggregation(avg, planBuilder.expression("avg(input order by input)"), ImmutableList.of(BIGINT), planBuilder.variable(mask)) + .addAggregation(arrayAgg, planBuilder.expression("array_agg(input order by input)"), ImmutableList.of(BIGINT), planBuilder.variable(mask)) + .hashVariable(planBuilder.variable(keyHash)) + .source(planBuilder.values(sourceVariables, ImmutableList.of()))); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneOutputColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneOutputColumns.java index 35c934f3e734e..2a6a8536e2409 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneOutputColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneOutputColumns.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.google.common.collect.ImmutableList; @@ -33,10 +34,11 @@ public void testNotAllOutputsReferenced() tester().assertThat(new PruneOutputColumns()) .on(p -> { - Symbol a = p.symbol("a"); - Symbol b = p.symbol("b"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); return p.output( ImmutableList.of("B label"), + ImmutableList.of(new Symbol("b")), ImmutableList.of(b), p.values(a, b)); }) @@ -54,10 +56,11 @@ public void testAllOutputsReferenced() tester().assertThat(new PruneOutputColumns()) .on(p -> { - Symbol a = p.symbol("a"); - Symbol b = p.symbol("b"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); return p.output( ImmutableList.of("A label", "B label"), + ImmutableList.of(new Symbol("a"), new Symbol("b")), ImmutableList.of(a, b), p.values(a, b)); }) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneProjectColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneProjectColumns.java index 6cab5811bfc0a..6c1ec76d36a72 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneProjectColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneProjectColumns.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.Assignments; import com.google.common.collect.ImmutableMap; @@ -31,8 +31,8 @@ public void testNotAllOutputsReferenced() { tester().assertThat(new PruneProjectColumns()) .on(p -> { - Symbol a = p.symbol("a"); - Symbol b = p.symbol("b"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); return p.project( Assignments.identity(b), p.project( @@ -52,8 +52,8 @@ public void testAllOutputsReferenced() { tester().assertThat(new PruneProjectColumns()) .on(p -> { - Symbol a = p.symbol("a"); - Symbol b = p.symbol("b"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); return p.project( Assignments.identity(b), p.project( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinColumns.java index 4dc7a637e27ac..fb57688ec127b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinColumns.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.Assignments; @@ -39,7 +39,7 @@ public class TestPruneSemiJoinColumns public void testSemiJoinNotNeeded() { tester().assertThat(new PruneSemiJoinColumns()) - .on(p -> buildProjectedSemiJoin(p, symbol -> symbol.getName().equals("leftValue"))) + .on(p -> buildProjectedSemiJoin(p, variable -> variable.getName().equals("leftValue"))) .matches( strictProject( ImmutableMap.of("leftValue", expression("leftValue")), @@ -50,7 +50,7 @@ public void testSemiJoinNotNeeded() public void testAllColumnsNeeded() { tester().assertThat(new PruneSemiJoinColumns()) - .on(p -> buildProjectedSemiJoin(p, symbol -> true)) + .on(p -> buildProjectedSemiJoin(p, variable -> true)) .doesNotFire(); } @@ -58,7 +58,7 @@ public void testAllColumnsNeeded() public void testKeysNotNeeded() { tester().assertThat(new PruneSemiJoinColumns()) - .on(p -> buildProjectedSemiJoin(p, symbol -> (symbol.getName().equals("leftValue") || symbol.getName().equals("match")))) + .on(p -> buildProjectedSemiJoin(p, variable -> (variable.getName().equals("leftValue") || variable.getName().equals("match")))) .doesNotFire(); } @@ -66,7 +66,7 @@ public void testKeysNotNeeded() public void testValueNotNeeded() { tester().assertThat(new PruneSemiJoinColumns()) - .on(p -> buildProjectedSemiJoin(p, symbol -> symbol.getName().equals("match"))) + .on(p -> buildProjectedSemiJoin(p, variable -> variable.getName().equals("match"))) .matches( strictProject( ImmutableMap.of("match", expression("match")), @@ -79,14 +79,14 @@ public void testValueNotNeeded() values("rightKey")))); } - private static PlanNode buildProjectedSemiJoin(PlanBuilder p, Predicate projectionFilter) + private static PlanNode buildProjectedSemiJoin(PlanBuilder p, Predicate projectionFilter) { - Symbol match = p.symbol("match"); - Symbol leftKey = p.symbol("leftKey"); - Symbol leftKeyHash = p.symbol("leftKeyHash"); - Symbol leftValue = p.symbol("leftValue"); - Symbol rightKey = p.symbol("rightKey"); - List outputs = ImmutableList.of(match, leftKey, leftKeyHash, leftValue); + VariableReferenceExpression match = p.variable("match"); + VariableReferenceExpression leftKey = p.variable("leftKey"); + VariableReferenceExpression leftKeyHash = p.variable("leftKeyHash"); + VariableReferenceExpression leftValue = p.variable("leftValue"); + VariableReferenceExpression rightKey = p.variable("rightKey"); + List outputs = ImmutableList.of(match, leftKey, leftKeyHash, leftValue); return p.project( Assignments.identity( outputs.stream() diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinFilteringSourceColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinFilteringSourceColumns.java index d64c103f1e96c..a325275b571e0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinFilteringSourceColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinFilteringSourceColumns.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -38,7 +38,7 @@ public class TestPruneSemiJoinFilteringSourceColumns public void testNotAllColumnsReferenced() { tester().assertThat(new PruneSemiJoinFilteringSourceColumns()) - .on(p -> buildSemiJoin(p, symbol -> true)) + .on(p -> buildSemiJoin(p, variable -> true)) .matches( semiJoin("leftKey", "rightKey", "match", values("leftKey"), @@ -53,18 +53,20 @@ public void testNotAllColumnsReferenced() public void testAllColumnsNeeded() { tester().assertThat(new PruneSemiJoinFilteringSourceColumns()) - .on(p -> buildSemiJoin(p, symbol -> !symbol.getName().equals("rightValue"))) + .on(p -> buildSemiJoin(p, variable -> !variable.getName().equals("rightValue"))) .doesNotFire(); } - private static PlanNode buildSemiJoin(PlanBuilder p, Predicate filteringSourceSymbolFilter) + private static PlanNode buildSemiJoin(PlanBuilder p, Predicate filteringSourceVariableFilter) { - Symbol match = p.symbol("match"); - Symbol leftKey = p.symbol("leftKey"); - Symbol rightKey = p.symbol("rightKey"); - Symbol rightKeyHash = p.symbol("rightKeyHash"); - Symbol rightValue = p.symbol("rightValue"); - List filteringSourceSymbols = ImmutableList.of(rightKey, rightKeyHash, rightValue); + VariableReferenceExpression match = p.variable("match"); + VariableReferenceExpression leftKey = p.variable("leftKey"); + VariableReferenceExpression rightKey = p.variable("rightKey"); + VariableReferenceExpression rightKeyHash = p.variable("rightKeyHash"); + VariableReferenceExpression rightValue = p.variable("rightValue"); + List filteringSourceVariables = ImmutableList.of(rightKey, rightKeyHash, rightValue); + List filteredSourceVariables = filteringSourceVariables.stream().filter(filteringSourceVariableFilter).collect(toImmutableList()); + return p.semiJoin( leftKey, rightKey, @@ -73,9 +75,7 @@ private static PlanNode buildSemiJoin(PlanBuilder p, Predicate filtering Optional.of(rightKeyHash), p.values(leftKey), p.values( - filteringSourceSymbols.stream() - .filter(filteringSourceSymbolFilter) - .collect(toImmutableList()), + filteredSourceVariables, ImmutableList.of())); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java index 7a74c01231e87..bddb82cbec52f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java @@ -15,6 +15,7 @@ import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; @@ -46,9 +47,11 @@ public void testNotAllOutputsReferenced() .on(p -> { Symbol orderdate = p.symbol("orderdate", DATE); + VariableReferenceExpression orderdateVariable = new VariableReferenceExpression(orderdate.getName(), DATE); Symbol totalprice = p.symbol("totalprice", DOUBLE); + VariableReferenceExpression totalpriceVariable = new VariableReferenceExpression(totalprice.getName(), DOUBLE); return p.project( - Assignments.of(p.symbol("x"), totalprice.toSymbolReference()), + Assignments.of(p.variable("x"), totalprice.toSymbolReference()), p.tableScan( new TableHandle( new ConnectorId("local"), @@ -56,9 +59,10 @@ public void testNotAllOutputsReferenced() TestingTransactionHandle.create(), Optional.empty()), ImmutableList.of(orderdate, totalprice), + ImmutableList.of(orderdateVariable, totalpriceVariable), ImmutableMap.of( - orderdate, new TpchColumnHandle(orderdate.getName(), DATE), - totalprice, new TpchColumnHandle(totalprice.getName(), DOUBLE)))); + orderdateVariable, new TpchColumnHandle(orderdate.getName(), DATE), + totalpriceVariable, new TpchColumnHandle(totalprice.getName(), DOUBLE)))); }) .matches( strictProject( @@ -70,12 +74,16 @@ orderdate, new TpchColumnHandle(orderdate.getName(), DATE), public void testAllOutputsReferenced() { tester().assertThat(new PruneTableScanColumns()) - .on(p -> - p.project( - Assignments.of(p.symbol("y"), expression("x")), - p.tableScan( - ImmutableList.of(p.symbol("x")), - ImmutableMap.of(p.symbol("x"), new TestingColumnHandle("x"))))) + .on(p -> { + Symbol x = p.symbol("x"); + VariableReferenceExpression xv = p.variable(x); + return p.project( + Assignments.of(p.variable("y"), expression("x")), + p.tableScan( + ImmutableList.of(x), + ImmutableList.of(xv), + ImmutableMap.of(p.variable(p.symbol("x")), new TestingColumnHandle("x")))); + }) .doesNotFire(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTopNColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTopNColumns.java index 486ef116d2c4b..f994bda05f389 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTopNColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTopNColumns.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.Assignments; @@ -43,7 +43,7 @@ public class TestPruneTopNColumns public void testNotAllInputsReferenced() { tester().assertThat(new PruneTopNColumns()) - .on(p -> buildProjectedTopN(p, symbol -> symbol.getName().equals("b"))) + .on(p -> buildProjectedTopN(p, variable -> variable.getName().equals("b"))) .matches( strictProject( ImmutableMap.of("b", expression("b")), @@ -59,7 +59,7 @@ public void testNotAllInputsReferenced() public void testAllInputsRereferenced() { tester().assertThat(new PruneTopNColumns()) - .on(p -> buildProjectedTopN(p, symbol -> symbol.getName().equals("a"))) + .on(p -> buildProjectedTopN(p, variable -> variable.getName().equals("a"))) .doesNotFire(); } @@ -71,10 +71,10 @@ public void testAllOutputsReferenced() .doesNotFire(); } - private ProjectNode buildProjectedTopN(PlanBuilder planBuilder, Predicate projectionTopN) + private ProjectNode buildProjectedTopN(PlanBuilder planBuilder, Predicate projectionTopN) { - Symbol a = planBuilder.symbol("a"); - Symbol b = planBuilder.symbol("b"); + VariableReferenceExpression a = planBuilder.variable("a"); + VariableReferenceExpression b = planBuilder.variable("b"); return planBuilder.project( Assignments.identity(ImmutableList.of(a, b).stream().filter(projectionTopN).collect(toImmutableSet())), planBuilder.topN( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneValuesColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneValuesColumns.java index e21c5d253edfd..1732c8a3e084c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneValuesColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneValuesColumns.java @@ -35,9 +35,9 @@ public void testNotAllOutputsReferenced() tester().assertThat(new PruneValuesColumns()) .on(p -> p.project( - Assignments.of(p.symbol("y"), expression("x")), + Assignments.of(p.variable("y"), expression("x")), p.values( - ImmutableList.of(p.symbol("unused"), p.symbol("x")), + ImmutableList.of(p.variable("unused"), p.variable("x")), ImmutableList.of( constantExpressions(BIGINT, 1, 2), constantExpressions(BIGINT, 3, 4))))) @@ -57,8 +57,8 @@ public void testAllOutputsReferenced() tester().assertThat(new PruneValuesColumns()) .on(p -> p.project( - Assignments.of(p.symbol("y"), expression("x")), - p.values(p.symbol("x")))) + Assignments.of(p.variable("y"), expression("x")), + p.values(p.variable("x")))) .doesNotFire(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java index afbc7a33ecb6e..aa606732cfff0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java @@ -17,7 +17,6 @@ import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.OrderingScheme; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; @@ -26,6 +25,7 @@ import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -184,23 +184,25 @@ public void testUnusedInputNotNeeded() private static PlanNode buildProjectedWindow( PlanBuilder p, - Predicate projectionFilter, - Predicate sourceFilter) + Predicate projectionFilter, + Predicate sourceFilter) { - Symbol orderKey = p.symbol("orderKey"); - Symbol partitionKey = p.symbol("partitionKey"); - Symbol hash = p.symbol("hash"); - Symbol startValue1 = p.symbol("startValue1"); - Symbol startValue2 = p.symbol("startValue2"); - Symbol endValue1 = p.symbol("endValue1"); - Symbol endValue2 = p.symbol("endValue2"); - Symbol input1 = p.symbol("input1"); - Symbol input2 = p.symbol("input2"); - Symbol unused = p.symbol("unused"); - Symbol output1 = p.symbol("output1"); - Symbol output2 = p.symbol("output2"); - List inputs = ImmutableList.of(orderKey, partitionKey, hash, startValue1, startValue2, endValue1, endValue2, input1, input2, unused); - List outputs = ImmutableList.builder().addAll(inputs).add(output1, output2).build(); + VariableReferenceExpression orderKey = p.variable("orderKey"); + VariableReferenceExpression partitionKey = p.variable("partitionKey"); + VariableReferenceExpression hash = p.variable("hash"); + VariableReferenceExpression startValue1 = p.variable("startValue1"); + VariableReferenceExpression startValue2 = p.variable("startValue2"); + VariableReferenceExpression endValue1 = p.variable("endValue1"); + VariableReferenceExpression endValue2 = p.variable("endValue2"); + VariableReferenceExpression input1 = p.variable("input1"); + VariableReferenceExpression input2 = p.variable("input2"); + VariableReferenceExpression unused = p.variable("unused"); + VariableReferenceExpression output1 = p.variable("output1"); + VariableReferenceExpression output2 = p.variable("output2"); + List inputs = ImmutableList.of(orderKey, partitionKey, hash, startValue1, startValue2, endValue1, endValue2, input1, input2, unused); + List outputs = ImmutableList.builder().addAll(inputs).add(output1, output2).build(); + + List filteredInputs = inputs.stream().filter(sourceFilter).collect(toImmutableList()); return p.project( Assignments.identity( @@ -216,31 +218,29 @@ private static PlanNode buildProjectedWindow( ImmutableMap.of( output1, new WindowNode.Function( - call(FUNCTION_NAME, FUNCTION_HANDLE, BIGINT, new VariableReferenceExpression(input1.getName(), BIGINT)), + call(FUNCTION_NAME, FUNCTION_HANDLE, BIGINT, input1), new WindowNode.Frame( RANGE, UNBOUNDED_PRECEDING, Optional.of(startValue1), CURRENT_ROW, Optional.of(endValue1), - Optional.of(startValue1.toSymbolReference()).map(Expression::toString), - Optional.of(endValue2.toSymbolReference()).map(Expression::toString))), + Optional.of(new SymbolReference(startValue1.getName())).map(Expression::toString), + Optional.of(new SymbolReference(endValue2.getName())).map(Expression::toString))), output2, new WindowNode.Function( - call(FUNCTION_NAME, FUNCTION_HANDLE, BIGINT, new VariableReferenceExpression(input2.getName(), BIGINT)), + call(FUNCTION_NAME, FUNCTION_HANDLE, BIGINT, input2), new WindowNode.Frame( RANGE, UNBOUNDED_PRECEDING, Optional.of(startValue2), CURRENT_ROW, Optional.of(endValue2), - Optional.of(startValue2.toSymbolReference()).map(Expression::toString), - Optional.of(endValue2.toSymbolReference()).map(Expression::toString)))), + Optional.of(new SymbolReference(startValue2.getName())).map(Expression::toString), + Optional.of(new SymbolReference(endValue2.getName())).map(Expression::toString)))), hash, p.values( - inputs.stream() - .filter(sourceFilter) - .collect(toImmutableList()), + filteredInputs, ImmutableList.of()))); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java index f72d1abc976b1..2364a63c19789 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.Assignments; @@ -53,15 +52,15 @@ public void testPushesAggregationThroughLeftJoin() .source( p.join( JoinNode.Type.LEFT, - p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(constantExpressions(BIGINT, 10))), - p.values(p.symbol("COL2")), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL1"), p.symbol("COL2"))), - ImmutableList.of(p.symbol("COL1"), p.symbol("COL2")), + p.values(ImmutableList.of(p.variable("COL1")), ImmutableList.of(constantExpressions(BIGINT, 10))), + p.values(p.variable("COL2")), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("COL1"), p.variable("COL2"))), + ImmutableList.of(p.variable("COL1"), p.variable("COL2")), Optional.empty(), Optional.empty(), Optional.empty())) - .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) - .singleGroupingSet(p.symbol("COL1")))) + .addAggregation(p.variable("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) + .singleGroupingSet(p.variable("COL1")))) .matches( project(ImmutableMap.of( "COL1", expression("COL1"), @@ -93,16 +92,17 @@ public void testPushesAggregationThroughLeftJoinWithOrderByFromRightSideColumn() .source( p.join( JoinNode.Type.LEFT, - p.values(ImmutableList.of(p.symbol("COL1"), p.symbol("COL3")), + p.values( + ImmutableList.of(p.variable("COL1"), p.variable("COL3")), ImmutableList.of(constantExpressions(BIGINT, 10, 20))), - p.values(p.symbol("COL2"), p.symbol("COL4")), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL1"), p.symbol("COL2"))), - ImmutableList.of(p.symbol("COL1"), p.symbol("COL2")), + p.values(p.variable("COL2"), p.variable("COL4")), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("COL1"), p.variable("COL2"))), + ImmutableList.of(p.variable("COL1"), p.variable("COL2")), Optional.empty(), Optional.empty(), Optional.empty())) - .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.expression("avg(COL2 ORDER BY COL4)"), ImmutableList.of(DOUBLE)) - .singleGroupingSet(p.symbol("COL1"), p.symbol("COL3")))) + .addAggregation(p.variable("AVG", DOUBLE), PlanBuilder.expression("avg(COL2 ORDER BY COL4)"), ImmutableList.of(DOUBLE)) + .singleGroupingSet(p.variable("COL1"), p.variable("COL3")))) .matches( project(ImmutableMap.of( "COL1", expression("COL1"), @@ -142,15 +142,15 @@ public void testPushesAggregationThroughRightJoin() .on(p -> p.aggregation(ab -> ab .source(p.join( JoinNode.Type.RIGHT, - p.values(p.symbol("COL2")), - p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(constantExpressions(BIGINT, 10))), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL2"), p.symbol("COL1"))), - ImmutableList.of(p.symbol("COL2"), p.symbol("COL1")), + p.values(p.variable("COL2")), + p.values(ImmutableList.of(p.variable("COL1")), ImmutableList.of(constantExpressions(BIGINT, 10))), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("COL2"), p.variable("COL1"))), + ImmutableList.of(p.variable("COL2"), p.variable("COL1")), Optional.empty(), Optional.empty(), Optional.empty())) - .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) - .singleGroupingSet(p.symbol("COL1")))) + .addAggregation(p.variable("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) + .singleGroupingSet(p.variable("COL1")))) .matches( project(ImmutableMap.of( "COALESCE", expression("coalesce(AVG, AVG_NULL)"), @@ -183,16 +183,16 @@ public void testDoesNotFireWhenNotDistinct() .source(p.join( JoinNode.Type.LEFT, p.values( - ImmutableList.of(p.symbol("COL1")), + ImmutableList.of(p.variable("COL1")), ImmutableList.of(constantExpressions(BIGINT, 10), constantExpressions(BIGINT, 11))), - p.values(new Symbol("COL2")), - ImmutableList.of(new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))), - ImmutableList.of(new Symbol("COL1"), new Symbol("COL2")), + p.values(p.variable("COL2")), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("COL1"), p.variable("COL2"))), + ImmutableList.of(p.variable("COL1"), p.variable("COL2")), Optional.empty(), Optional.empty(), Optional.empty())) - .addAggregation(new Symbol("AVG"), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) - .singleGroupingSet(new Symbol("COL1")))) + .addAggregation(p.variable("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) + .singleGroupingSet(p.variable("COL1")))) .doesNotFire(); // https://github.com/prestodb/presto/issues/10592 @@ -202,22 +202,22 @@ public void testDoesNotFireWhenNotDistinct() p.join( JoinNode.Type.LEFT, p.project(Assignments.builder() - .putIdentity(p.symbol("COL1", BIGINT)) + .putIdentity(p.variable("COL1", BIGINT)) .build(), p.aggregation(builder -> - builder.singleGroupingSet(p.symbol("COL1"), p.symbol("unused")) + builder.singleGroupingSet(p.variable("COL1"), p.variable("unused")) .source( p.values( - ImmutableList.of(p.symbol("COL1"), p.symbol("unused")), + ImmutableList.of(p.variable("COL1"), p.variable("unused")), ImmutableList.of(constantExpressions(BIGINT, 10, 1), constantExpressions(BIGINT, 10, 2)))))), - p.values(p.symbol("COL2")), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL1"), p.symbol("COL2"))), - ImmutableList.of(p.symbol("COL1"), p.symbol("COL2")), + p.values(p.variable("COL2")), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("COL1"), p.variable("COL2"))), + ImmutableList.of(p.variable("COL1"), p.variable("COL2")), Optional.empty(), Optional.empty(), Optional.empty())) - .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) - .singleGroupingSet(p.symbol("COL1")))) + .addAggregation(p.variable("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) + .singleGroupingSet(p.variable("COL1")))) .doesNotFire(); } @@ -227,15 +227,15 @@ public void testDoesNotFireWhenGroupingOnInner() tester().assertThat(new PushAggregationThroughOuterJoin(getFunctionManager())) .on(p -> p.aggregation(ab -> ab .source(p.join(JoinNode.Type.LEFT, - p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(constantExpressions(BIGINT, 10))), - p.values(new Symbol("COL2"), new Symbol("COL3")), - ImmutableList.of(new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))), - ImmutableList.of(new Symbol("COL1"), new Symbol("COL2")), + p.values(ImmutableList.of(p.variable("COL1")), ImmutableList.of(constantExpressions(BIGINT, 10))), + p.values(p.variable("COL2"), p.variable("COL3")), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("COL1"), p.variable("COL2"))), + ImmutableList.of(p.variable("COL1"), p.variable("COL2")), Optional.empty(), Optional.empty(), Optional.empty())) - .addAggregation(new Symbol("AVG"), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) - .singleGroupingSet(new Symbol("COL1"), new Symbol("COL3")))) + .addAggregation(p.variable("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) + .singleGroupingSet(p.variable("COL1"), p.variable("COL3")))) .doesNotFire(); } @@ -246,15 +246,15 @@ public void testDoesNotFireWhenAggregationDoesNotHaveSymbols() .on(p -> p.aggregation(ab -> ab .source(p.join( JoinNode.Type.LEFT, - p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(constantExpressions(BIGINT, 10))), - p.values(ImmutableList.of(p.symbol("COL2")), ImmutableList.of(constantExpressions(BIGINT, 20))), - ImmutableList.of(new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))), - ImmutableList.of(new Symbol("COL1"), new Symbol("COL2")), + p.values(ImmutableList.of(p.variable("COL1")), ImmutableList.of(constantExpressions(BIGINT, 10))), + p.values(ImmutableList.of(p.variable("COL2")), ImmutableList.of(constantExpressions(BIGINT, 20))), + ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("COL1"), p.variable("COL2"))), + ImmutableList.of(p.variable("COL1"), p.variable("COL2")), Optional.empty(), Optional.empty(), Optional.empty())) - .addAggregation(new Symbol("SUM"), PlanBuilder.expression("sum(COL1)"), ImmutableList.of(DOUBLE)) - .singleGroupingSet(new Symbol("COL1")))) + .addAggregation(p.variable("SUM", DOUBLE), PlanBuilder.expression("sum(COL1)"), ImmutableList.of(DOUBLE)) + .singleGroupingSet(p.variable("COL1")))) .doesNotFire(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughMarkDistinct.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughMarkDistinct.java index 97e3a21abc986..66027b7dbfcfa 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughMarkDistinct.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughMarkDistinct.java @@ -33,7 +33,7 @@ public void test() p.limit( 1, p.markDistinct( - p.symbol("foo"), ImmutableList.of(p.symbol("bar")), p.values()))) + p.variable(p.symbol("foo")), ImmutableList.of(p.variable("bar")), p.values()))) .matches( node(MarkDistinctNode.class, node(LimitNode.class, @@ -46,8 +46,8 @@ public void testDoesNotFire() tester().assertThat(new PushLimitThroughMarkDistinct()) .on(p -> p.markDistinct( - p.symbol("foo"), - ImmutableList.of(p.symbol("bar")), + p.variable(p.symbol("foo")), + ImmutableList.of(p.variable("bar")), p.limit( 1, p.values()))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughOuterJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughOuterJoin.java index e39b81d521d24..f8b884cdab1cf 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughOuterJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughOuterJoin.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.JoinNode.EquiJoinClause; import com.google.common.collect.ImmutableList; @@ -34,8 +34,8 @@ public void testPushLimitThroughLeftJoin() { tester().assertThat(new PushLimitThroughOuterJoin()) .on(p -> { - Symbol leftKey = p.symbol("leftKey"); - Symbol rightKey = p.symbol("rightKey"); + VariableReferenceExpression leftKey = p.variable("leftKey"); + VariableReferenceExpression rightKey = p.variable("rightKey"); return p.limit(1, p.join( LEFT, @@ -57,8 +57,8 @@ public void testDoesNotPushThroughFullOuterJoin() { tester().assertThat(new PushLimitThroughOuterJoin()) .on(p -> { - Symbol leftKey = p.symbol("leftKey"); - Symbol rightKey = p.symbol("rightKey"); + VariableReferenceExpression leftKey = p.variable("leftKey"); + VariableReferenceExpression rightKey = p.variable("rightKey"); return p.limit(1, p.join( FULL, @@ -74,8 +74,8 @@ public void testDoNotPushWhenAlreadyLimited() { tester().assertThat(new PushLimitThroughOuterJoin()) .on(p -> { - Symbol leftKey = p.symbol("leftKey"); - Symbol rightKey = p.symbol("rightKey"); + VariableReferenceExpression leftKey = p.variable("leftKey"); + VariableReferenceExpression rightKey = p.variable("rightKey"); return p.limit(1, p.join( LEFT, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughProject.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughProject.java index d6ea56e937a4b..00176923766eb 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughProject.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughProject.java @@ -13,9 +13,10 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -33,7 +34,7 @@ public void testPushdownLimitNonIdentityProjection() { tester().assertThat(new PushLimitThroughProject()) .on(p -> { - Symbol a = p.symbol("a"); + VariableReferenceExpression a = p.variable("a"); return p.limit(1, p.project( Assignments.of(a, TRUE_LITERAL), @@ -50,10 +51,10 @@ public void testDoesntPushdownLimitThroughIdentityProjection() { tester().assertThat(new PushLimitThroughProject()) .on(p -> { - Symbol a = p.symbol("a"); + VariableReferenceExpression a = p.variable("a"); return p.limit(1, p.project( - Assignments.of(a, a.toSymbolReference()), + Assignments.of(a, new SymbolReference(a.getName())), p.values(a))); }).doesNotFire(); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java index 7dd4711dda361..30112445450df 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java @@ -47,15 +47,15 @@ public void testPushesPartialAggregationThroughJoin() .source( p.join( INNER, - p.values(p.symbol("LEFT_EQUI"), p.symbol("LEFT_NON_EQUI"), p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_AGGR"), p.symbol("LEFT_HASH")), - p.values(p.symbol("RIGHT_EQUI"), p.symbol("RIGHT_NON_EQUI"), p.symbol("RIGHT_GROUP_BY"), p.symbol("RIGHT_HASH")), - ImmutableList.of(new EquiJoinClause(p.symbol("LEFT_EQUI"), p.symbol("RIGHT_EQUI"))), - ImmutableList.of(p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_AGGR"), p.symbol("RIGHT_GROUP_BY")), + p.values(p.variable("LEFT_EQUI"), p.variable("LEFT_NON_EQUI"), p.variable("LEFT_GROUP_BY"), p.variable("LEFT_AGGR"), p.variable("LEFT_HASH")), + p.values(p.variable("RIGHT_EQUI"), p.variable("RIGHT_NON_EQUI"), p.variable("RIGHT_GROUP_BY"), p.variable("RIGHT_HASH")), + ImmutableList.of(new EquiJoinClause(p.variable("LEFT_EQUI"), p.variable("RIGHT_EQUI"))), + ImmutableList.of(p.variable("LEFT_GROUP_BY"), p.variable("LEFT_AGGR"), p.variable("RIGHT_GROUP_BY")), Optional.of(expression("LEFT_NON_EQUI <= RIGHT_NON_EQUI")), - Optional.of(p.symbol("LEFT_HASH")), - Optional.of(p.symbol("RIGHT_HASH")))) - .addAggregation(p.symbol("AVG", DOUBLE), expression("AVG(LEFT_AGGR)"), ImmutableList.of(DOUBLE)) - .singleGroupingSet(p.symbol("LEFT_GROUP_BY"), p.symbol("RIGHT_GROUP_BY")) + Optional.of(p.variable("LEFT_HASH")), + Optional.of(p.variable("RIGHT_HASH")))) + .addAggregation(p.variable("AVG", DOUBLE), expression("AVG(LEFT_AGGR)"), ImmutableList.of(DOUBLE)) + .singleGroupingSet(p.variable("LEFT_GROUP_BY"), p.variable("RIGHT_GROUP_BY")) .step(PARTIAL))) .matches(project(ImmutableMap.of( "LEFT_GROUP_BY", PlanMatchPattern.expression("LEFT_GROUP_BY"), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java index 0a0e7c1082041..a4bcc2e7745c5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java @@ -14,8 +14,8 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.block.SortOrder; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.OrderingScheme; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; @@ -44,8 +44,8 @@ public void testDoesNotFireNoExchange() tester().assertThat(new PushProjectionThroughExchange()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new LongLiteral("3")), - p.values(p.symbol("a")))) + Assignments.of(p.variable("x"), new LongLiteral("3")), + p.values(p.variable("a")))) .doesNotFire(); } @@ -54,14 +54,14 @@ public void testDoesNotFireNarrowingProjection() { tester().assertThat(new PushProjectionThroughExchange()) .on(p -> { - Symbol a = p.symbol("a"); - Symbol b = p.symbol("b"); - Symbol c = p.symbol("c"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); return p.project( Assignments.builder() - .put(a, a.toSymbolReference()) - .put(b, b.toSymbolReference()) + .put(a, new SymbolReference(a.getName())) + .put(b, new SymbolReference(b.getName())) .build(), p.exchange(e -> e .addSource(p.values(a, b, c)) @@ -76,11 +76,11 @@ public void testSimpleMultipleInputs() { tester().assertThat(new PushProjectionThroughExchange()) .on(p -> { - Symbol a = p.symbol("a"); - Symbol b = p.symbol("b"); - Symbol c = p.symbol("c"); - Symbol c2 = p.symbol("c2"); - Symbol x = p.symbol("x"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression c2 = p.variable("c2"); + VariableReferenceExpression x = p.variable("x"); return p.project( Assignments.of( x, new LongLiteral("3"), @@ -112,12 +112,12 @@ public void testPartitioningColumnAndHashWithoutIdentityMappingInProjection() { tester().assertThat(new PushProjectionThroughExchange()) .on(p -> { - Symbol a = p.symbol("a"); - Symbol b = p.symbol("b"); - Symbol h = p.symbol("h"); - Symbol aTimes5 = p.symbol("a_times_5"); - Symbol bTimes5 = p.symbol("b_times_5"); - Symbol hTimes5 = p.symbol("h_times_5"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression h = p.variable("h"); + VariableReferenceExpression aTimes5 = p.variable("a_times_5"); + VariableReferenceExpression bTimes5 = p.variable("b_times_5"); + VariableReferenceExpression hTimes5 = p.variable("h_times_5"); return p.project( Assignments.builder() .put(aTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("a"), new LongLiteral("5"))) @@ -154,14 +154,14 @@ public void testOrderingColumnsArePreserved() { tester().assertThat(new PushProjectionThroughExchange()) .on(p -> { - Symbol a = p.symbol("a"); - Symbol b = p.symbol("b"); - Symbol h = p.symbol("h"); - Symbol aTimes5 = p.symbol("a_times_5"); - Symbol bTimes5 = p.symbol("b_times_5"); - Symbol hTimes5 = p.symbol("h_times_5"); - Symbol sortSymbol = p.symbol("sortSymbol"); - OrderingScheme orderingScheme = new OrderingScheme(ImmutableList.of(sortSymbol), ImmutableMap.of(sortSymbol, SortOrder.ASC_NULLS_FIRST)); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression h = p.variable("h"); + VariableReferenceExpression aTimes5 = p.variable("a_times_5"); + VariableReferenceExpression bTimes5 = p.variable("b_times_5"); + VariableReferenceExpression hTimes5 = p.variable("h_times_5"); + VariableReferenceExpression sortVariable = p.variable("sortVariable"); + OrderingScheme orderingScheme = new OrderingScheme(ImmutableList.of(sortVariable), ImmutableMap.of(sortVariable, SortOrder.ASC_NULLS_FIRST)); return p.project( Assignments.builder() .put(aTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("a"), new LongLiteral("5"))) @@ -170,10 +170,10 @@ public void testOrderingColumnsArePreserved() .build(), p.exchange(e -> e .addSource( - p.values(a, b, h, sortSymbol)) - .addInputsSet(a, b, h, sortSymbol) + p.values(a, b, h, sortVariable)) + .addInputsSet(a, b, h, sortVariable) .singleDistributionPartitioningScheme( - ImmutableList.of(a, b, h, sortSymbol)) + ImmutableList.of(a, b, h, sortVariable)) .orderingScheme(orderingScheme))); }) .matches( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java index 2b22593cb6fa5..e0677da3253a7 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java @@ -13,11 +13,12 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; @@ -37,8 +38,8 @@ public void testDoesNotFire() tester().assertThat(new PushProjectionThroughUnion()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new LongLiteral("3")), - p.values(p.symbol("a")))) + Assignments.of(p.variable("x"), new LongLiteral("3")), + p.values(p.variable("a")))) .doesNotFire(); } @@ -47,14 +48,13 @@ public void test() { tester().assertThat(new PushProjectionThroughUnion()) .on(p -> { - Symbol a = p.symbol("a"); - Symbol b = p.symbol("b"); - Symbol c = p.symbol("c"); - Symbol cTimes3 = p.symbol("c_times_3"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); return p.project( - Assignments.of(cTimes3, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, c.toSymbolReference(), new LongLiteral("3"))), + Assignments.of(p.variable("c_times_3"), new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference(c.getName()), new LongLiteral("3"))), p.union( - ImmutableListMultimap.builder() + ImmutableListMultimap.builder() .put(c, a) .put(c, b) .build(), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushTableWriteThroughUnion.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushTableWriteThroughUnion.java index 71cc9e28c11f6..b07f7117bb344 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushTableWriteThroughUnion.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushTableWriteThroughUnion.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; @@ -34,15 +34,15 @@ public void testPushThroughUnion() tester().assertThat(new PushTableWriteThroughUnion()) .on(p -> p.tableWriter( - ImmutableList.of(p.symbol("A", BIGINT), p.symbol("B", BIGINT)), ImmutableList.of("a", "b"), + ImmutableList.of(p.variable("A", BIGINT), p.variable("B", BIGINT)), ImmutableList.of("a", "b"), p.union( - ImmutableListMultimap.builder() - .putAll(p.symbol("A", BIGINT), p.symbol("A1", BIGINT), p.symbol("B2", BIGINT)) - .putAll(p.symbol("B", BIGINT), p.symbol("B1", BIGINT), p.symbol("A2", BIGINT)) + ImmutableListMultimap.builder() + .putAll(p.variable("A", BIGINT), p.variable("A1", BIGINT), p.variable("B2", BIGINT)) + .putAll(p.variable("B", BIGINT), p.variable("B1", BIGINT), p.variable("A2", BIGINT)) .build(), ImmutableList.of( - p.values(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), - p.values(p.symbol("A2", BIGINT), p.symbol("B2", BIGINT)))))) + p.values(p.variable("A1", BIGINT), p.variable("B1", BIGINT)), + p.values(p.variable("A2", BIGINT), p.variable("B2", BIGINT)))))) .matches(union( tableWriter(ImmutableList.of("A1", "B1"), ImmutableList.of("a", "b"), values(ImmutableMap.of("A1", 0, "B1", 1))), tableWriter(ImmutableList.of("B2", "A2"), ImmutableList.of("a", "b"), values(ImmutableMap.of("A2", 0, "B2", 1))))); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveEmptyDelete.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveEmptyDelete.java index b8f82067b330a..fbc01441e542b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveEmptyDelete.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveEmptyDelete.java @@ -44,8 +44,9 @@ public void testDoesNotFire() TestingTransactionHandle.create(), Optional.empty()), ImmutableList.of(), + ImmutableList.of(), ImmutableMap.of()), - p.symbol("a", BigintType.BIGINT))) + p.variable(p.symbol("a", BigintType.BIGINT)))) .doesNotFire(); } @@ -56,7 +57,7 @@ public void test() .on(p -> p.tableDelete( new SchemaTableName("sch", "tab"), p.values(), - p.symbol("a", BigintType.BIGINT))) + p.variable(p.symbol("a", BigintType.BIGINT)))) .matches( PlanMatchPattern.values(ImmutableMap.of("a", 0))); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveFullSample.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveFullSample.java index 34bf8c2982168..85c1b8dc99519 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveFullSample.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveFullSample.java @@ -36,7 +36,7 @@ public void testDoesNotFire() p.sample( 0.15, Type.BERNOULLI, - p.values(p.symbol("a")))) + p.values(p.variable("a")))) .doesNotFire(); } @@ -51,7 +51,7 @@ public void test() p.filter( expression("b > 5"), p.values( - ImmutableList.of(p.symbol("a"), p.symbol("b")), + ImmutableList.of(p.variable(p.symbol("a")), p.variable(p.symbol("b"))), ImmutableList.of( constantExpressions(BIGINT, 1, 10), constantExpressions(BIGINT, 2, 11)))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveTrivialFilters.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveTrivialFilters.java index 5c6ec3e07469c..0c46900bf18b5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveTrivialFilters.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveTrivialFilters.java @@ -47,7 +47,7 @@ public void testRemovesFalseFilter() .on(p -> p.filter( p.expression("FALSE"), p.values( - ImmutableList.of(p.symbol("a")), + ImmutableList.of(p.variable(p.symbol("a"))), ImmutableList.of(constantExpressions(BIGINT, 1))))) .matches(values("a")); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarApplyNodes.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarApplyNodes.java index 1532466c06cbb..c35e841c47e3d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarApplyNodes.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarApplyNodes.java @@ -29,10 +29,10 @@ public void testDoesNotFire() { tester().assertThat(new RemoveUnreferencedScalarApplyNodes()) .on(p -> p.apply( - Assignments.of(p.symbol("z"), p.expression("x IN (y)")), + Assignments.of(p.variable("z"), p.expression("x IN (y)")), ImmutableList.of(), - p.values(p.symbol("x")), - p.values(p.symbol("y")))) + p.values(p.variable("x")), + p.values(p.variable("y")))) .doesNotFire(); } @@ -43,8 +43,8 @@ public void testEmptyAssignments() .on(p -> p.apply( Assignments.of(), ImmutableList.of(), - p.values(p.symbol("x")), - p.values(p.symbol("y")))) + p.values(p.variable("x")), + p.values(p.variable("y")))) .matches(values("x")); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarLateralNodes.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarLateralNodes.java index e7ef0fd1dc2c0..9423866219365 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarLateralNodes.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarLateralNodes.java @@ -30,7 +30,7 @@ public void testRemoveUnreferencedInput() tester().assertThat(new RemoveUnreferencedScalarLateralNodes()) .on(p -> p.lateral( emptyList(), - p.values(p.symbol("x", BigintType.BIGINT)), + p.values(p.variable("x", BigintType.BIGINT)), p.values(emptyList(), ImmutableList.of(emptyList())))) .matches(values("x")); } @@ -42,7 +42,7 @@ public void testRemoveUnreferencedSubquery() .on(p -> p.lateral( emptyList(), p.values(emptyList(), ImmutableList.of(emptyList())), - p.values(p.symbol("x", BigintType.BIGINT)))) + p.values(p.variable("x", BigintType.BIGINT)))) .matches(values("x")); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java index bcb68a317f79d..691eea337790b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java @@ -15,12 +15,12 @@ import com.facebook.presto.cost.CostComparator; import com.facebook.presto.cost.PlanNodeStatsEstimate; -import com.facebook.presto.cost.SymbolStatsEstimate; +import com.facebook.presto.cost.VariableStatsEstimate; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.RuleAssert; import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; @@ -40,6 +40,7 @@ import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.JOIN_MAX_BROADCAST_TABLE_SIZE; import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.AUTOMATIC; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.BROADCAST; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; @@ -84,20 +85,20 @@ public void testKeepsOutputSymbols() .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.symbol("A1"), p.symbol("A2")), TWO_ROWS), - p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.symbol("B1")), TWO_ROWS), - ImmutableList.of(new EquiJoinClause(p.symbol("A1"), p.symbol("B1"))), - ImmutableList.of(p.symbol("A2")), + p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.variable("A1"), p.variable("A2")), TWO_ROWS), + p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.variable("B1")), TWO_ROWS), + ImmutableList.of(new EquiJoinClause(p.variable("A1"), p.variable("B1"))), + ImmutableList.of(p.variable("A2")), Optional.empty())) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(5000) - .addSymbolStatistics(ImmutableMap.of( - new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 100), - new Symbol("A2"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) + .addVariableStatistics(ImmutableMap.of( + new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 100, 100), + new VariableReferenceExpression("A2", BIGINT), new VariableStatsEstimate(0, 100, 0, 100, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 100, 100))) .build()) .matches(join( INNER, @@ -116,18 +117,18 @@ public void testReplicatesAndFlipsWhenOneTableMuchSmaller() .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.symbol("A1")), TWO_ROWS), - p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.symbol("B1")), TWO_ROWS), - ImmutableList.of(new EquiJoinClause(p.symbol("A1"), p.symbol("B1"))), - ImmutableList.of(p.symbol("A1"), p.symbol("B1")), + p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.variable("A1")), TWO_ROWS), + p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.variable("B1")), TWO_ROWS), + ImmutableList.of(new EquiJoinClause(p.variable("A1"), p.variable("B1"))), + ImmutableList.of(p.variable("A1"), p.variable("B1")), Optional.empty())) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(100) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 6400, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .matches(join( INNER, @@ -145,19 +146,19 @@ public void testRepartitionsWhenRequiredBySession() .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.symbol("A1")), TWO_ROWS), - p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.symbol("B1")), TWO_ROWS), - ImmutableList.of(new EquiJoinClause(p.symbol("A1"), p.symbol("B1"))), - ImmutableList.of(p.symbol("A1"), p.symbol("B1")), + p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.variable("A1")), TWO_ROWS), + p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.variable("B1")), TWO_ROWS), + ImmutableList.of(new EquiJoinClause(p.variable("A1"), p.variable("B1"))), + ImmutableList.of(p.variable("A1"), p.variable("B1")), Optional.empty())) .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.PARTITIONED.name()) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(100) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 6400, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .matches(join( INNER, @@ -175,18 +176,18 @@ public void testRepartitionsWhenBothTablesEqual() .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.symbol("A1")), TWO_ROWS), - p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.symbol("B1")), TWO_ROWS), - ImmutableList.of(new EquiJoinClause(p.symbol("A1"), p.symbol("B1"))), - ImmutableList.of(p.symbol("A1"), p.symbol("B1")), + p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.variable("A1")), TWO_ROWS), + p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.variable("B1")), TWO_ROWS), + ImmutableList.of(new EquiJoinClause(p.variable("A1"), p.variable("B1"))), + ImmutableList.of(p.variable("A1"), p.variable("B1")), Optional.empty())) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .matches(join( INNER, @@ -204,20 +205,20 @@ public void testReplicatesUnrestrictedWhenRequiredBySession() .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.symbol("A1")), TWO_ROWS), - p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.symbol("B1")), TWO_ROWS), - ImmutableList.of(new EquiJoinClause(p.symbol("A1"), p.symbol("B1"))), - ImmutableList.of(p.symbol("A1"), p.symbol("B1")), + p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.variable("A1")), TWO_ROWS), + p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.variable("B1")), TWO_ROWS), + ImmutableList.of(new EquiJoinClause(p.variable("A1"), p.variable("B1"))), + ImmutableList.of(p.variable("A1"), p.variable("B1")), Optional.empty())) .setSystemProperty(JOIN_MAX_BROADCAST_TABLE_SIZE, "1kB") .setSystemProperty(JOIN_DISTRIBUTION_TYPE, BROADCAST.name()) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .matches(join( INNER, @@ -241,11 +242,11 @@ public void testReplicatedScalarJoinEvenWhereSessionRequiresRepartitioned() PlanNodeStatsEstimate valuesA = PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build(); PlanNodeStatsEstimate valuesB = PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build(); assertReorderJoins() @@ -253,10 +254,10 @@ public void testReplicatedScalarJoinEvenWhereSessionRequiresRepartitioned() .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), p.symbol("A1")), // matches isAtMostScalar - p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.symbol("B1")), TWO_ROWS), - ImmutableList.of(new EquiJoinClause(p.symbol("A1"), p.symbol("B1"))), - ImmutableList.of(p.symbol("A1"), p.symbol("B1")), + p.values(new PlanNodeId("valuesA"), p.variable("A1")), // matches isAtMostScalar + p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.variable("B1")), TWO_ROWS), + ImmutableList.of(new EquiJoinClause(p.variable("A1"), p.variable("B1"))), + ImmutableList.of(p.variable("A1"), p.variable("B1")), Optional.empty())) .overrideStats("valuesA", valuesA) .overrideStats("valuesB", valuesB) @@ -267,10 +268,10 @@ public void testReplicatedScalarJoinEvenWhereSessionRequiresRepartitioned() .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.symbol("B1")), TWO_ROWS), - p.values(new PlanNodeId("valuesA"), p.symbol("A1")), // matches isAtMostScalar - ImmutableList.of(new EquiJoinClause(p.symbol("A1"), p.symbol("B1"))), - ImmutableList.of(p.symbol("A1"), p.symbol("B1")), + p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.variable("B1")), TWO_ROWS), + p.values(new PlanNodeId("valuesA"), p.variable("A1")), // matches isAtMostScalar + ImmutableList.of(new EquiJoinClause(p.variable("A1"), p.variable("B1"))), + ImmutableList.of(p.variable("A1"), p.variable("B1")), Optional.empty())) .overrideStats("valuesA", valuesA) .overrideStats("valuesB", valuesB) @@ -284,18 +285,18 @@ public void testDoesNotFireForCrossJoin() .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.symbol("A1")), TWO_ROWS), - p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.symbol("B1")), TWO_ROWS), + p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.variable("A1")), TWO_ROWS), + p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.variable("B1")), TWO_ROWS), ImmutableList.of(), - ImmutableList.of(p.symbol("A1"), p.symbol("B1")), + ImmutableList.of(p.variable("A1"), p.variable("B1")), Optional.empty())) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 100))) .build()) .doesNotFire(); } @@ -307,10 +308,10 @@ public void testDoesNotFireWithNoStats() .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.symbol("A1")), TWO_ROWS), - p.values(new PlanNodeId("valuesB"), p.symbol("B1")), - ImmutableList.of(new EquiJoinClause(p.symbol("A1"), p.symbol("B1"))), - ImmutableList.of(p.symbol("A1")), + p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.variable("A1")), TWO_ROWS), + p.values(new PlanNodeId("valuesB"), p.variable("B1")), + ImmutableList.of(new EquiJoinClause(p.variable("A1"), p.variable("B1"))), + ImmutableList.of(p.variable("A1")), Optional.empty())) .overrideStats("valuesA", PlanNodeStatsEstimate.unknown()) .doesNotFire(); @@ -323,10 +324,10 @@ public void testDoesNotFireForNonDeterministicFilter() .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), p.symbol("A1")), - p.values(new PlanNodeId("valuesB"), p.symbol("B1")), - ImmutableList.of(new EquiJoinClause(p.symbol("A1"), p.symbol("B1"))), - ImmutableList.of(p.symbol("A1"), p.symbol("B1")), + p.values(new PlanNodeId("valuesA"), p.variable("A1")), + p.values(new PlanNodeId("valuesB"), p.variable("B1")), + ImmutableList.of(new EquiJoinClause(p.variable("A1"), p.variable("B1"))), + ImmutableList.of(p.variable("A1"), p.variable("B1")), Optional.of(new ComparisonExpression(LESS_THAN, p.symbol("A1").toSymbolReference(), new FunctionCall(QualifiedName.of("random"), ImmutableList.of()))))) .doesNotFire(); } @@ -340,29 +341,29 @@ public void testPredicatesPushedDown() INNER, p.join( INNER, - p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.symbol("A1")), TWO_ROWS), - p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.symbol("B1"), p.symbol("B2")), TWO_ROWS), + p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.variable("A1")), TWO_ROWS), + p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.variable("B1"), p.variable("B2")), TWO_ROWS), ImmutableList.of(), - ImmutableList.of(p.symbol("A1"), p.symbol("B1"), p.symbol("B2")), + ImmutableList.of(p.variable("A1"), p.variable("B1"), p.variable("B2")), Optional.empty()), - p.values(new PlanNodeId("valuesC"), ImmutableList.of(p.symbol("C1")), TWO_ROWS), + p.values(new PlanNodeId("valuesC"), ImmutableList.of(p.variable("C1")), TWO_ROWS), ImmutableList.of( - new EquiJoinClause(p.symbol("B2"), p.symbol("C1"))), - ImmutableList.of(p.symbol("A1")), + new EquiJoinClause(p.variable("B2"), p.variable("C1"))), + ImmutableList.of(p.variable("A1")), Optional.of(new ComparisonExpression(EQUAL, p.symbol("A1").toSymbolReference(), p.symbol("B1").toSymbolReference())))) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(10) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 10))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 100, 10))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(5) - .addSymbolStatistics(ImmutableMap.of( - new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 5), - new Symbol("B2"), new SymbolStatsEstimate(0, 100, 0, 100, 5))) + .addVariableStatistics(ImmutableMap.of( + new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 100, 5), + new VariableReferenceExpression("B2", BIGINT), new VariableStatsEstimate(0, 100, 0, 100, 5))) .build()) .overrideStats("valuesC", PlanNodeStatsEstimate.builder() .setOutputRowCount(1000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("C1"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("C1", BIGINT), new VariableStatsEstimate(0, 100, 0, 100, 100))) .build()) .matches( join( @@ -385,29 +386,29 @@ public void testSmallerJoinFirst() INNER, p.join( INNER, - p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.symbol("A1")), TWO_ROWS), - p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.symbol("B1"), p.symbol("B2")), TWO_ROWS), - ImmutableList.of(new EquiJoinClause(p.symbol("A1"), p.symbol("B1"))), - ImmutableList.of(p.symbol("A1"), p.symbol("B1"), p.symbol("B2")), + p.values(new PlanNodeId("valuesA"), ImmutableList.of(p.variable("A1")), TWO_ROWS), + p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.variable("B1"), p.variable("B2")), TWO_ROWS), + ImmutableList.of(new EquiJoinClause(p.variable("A1"), p.variable("B1"))), + ImmutableList.of(p.variable("A1"), p.variable("B1"), p.variable("B2")), Optional.empty()), - p.values(new PlanNodeId("valuesC"), ImmutableList.of(p.symbol("C1")), TWO_ROWS), + p.values(new PlanNodeId("valuesC"), ImmutableList.of(p.variable("C1")), TWO_ROWS), ImmutableList.of( - new EquiJoinClause(p.symbol("B2"), p.symbol("C1"))), - ImmutableList.of(p.symbol("A1")), + new EquiJoinClause(p.variable("B2"), p.variable("C1"))), + ImmutableList.of(p.variable("A1")), Optional.of(new ComparisonExpression(EQUAL, p.symbol("A1").toSymbolReference(), p.symbol("B1").toSymbolReference())))) .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setOutputRowCount(40) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 10))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 100, 10))) .build()) .overrideStats("valuesB", PlanNodeStatsEstimate.builder() .setOutputRowCount(10) - .addSymbolStatistics(ImmutableMap.of( - new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 10), - new Symbol("B2"), new SymbolStatsEstimate(0, 100, 0, 100, 10))) + .addVariableStatistics(ImmutableMap.of( + new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 100, 10), + new VariableReferenceExpression("B2", BIGINT), new VariableStatsEstimate(0, 100, 0, 100, 10))) .build()) .overrideStats("valuesC", PlanNodeStatsEstimate.builder() .setOutputRowCount(100) - .addSymbolStatistics(ImmutableMap.of(new Symbol("C1"), new SymbolStatsEstimate(99, 199, 0, 100, 100))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("C1", BIGINT), new VariableStatsEstimate(99, 199, 0, 100, 100))) .build()) .matches( join( @@ -429,11 +430,11 @@ public void testReplicatesWhenNotRestricted() PlanNodeStatsEstimate probeSideStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 10))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 10))) .build(); PlanNodeStatsEstimate buildSideStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 10))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000, 10))) .build(); // B table is small enough to be replicated in AUTOMATIC_RESTRICTED mode @@ -443,10 +444,10 @@ public void testReplicatesWhenNotRestricted() .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1")), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1")), - ImmutableList.of(new EquiJoinClause(p.symbol("A1"), p.symbol("B1"))), - ImmutableList.of(p.symbol("A1"), p.symbol("B1")), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1")), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1")), + ImmutableList.of(new EquiJoinClause(p.variable("A1"), p.variable("B1"))), + ImmutableList.of(p.variable("A1"), p.variable("B1")), Optional.empty())) .overrideStats("valuesA", probeSideStatsEstimate) .overrideStats("valuesB", buildSideStatsEstimate) @@ -460,11 +461,11 @@ public void testReplicatesWhenNotRestricted() probeSideStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(aRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000d * 10000, 10))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("A1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) .build(); buildSideStatsEstimate = PlanNodeStatsEstimate.builder() .setOutputRowCount(bRows) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000d * 10000, 10))) + .addVariableStatistics(ImmutableMap.of(new VariableReferenceExpression("B1", BIGINT), new VariableStatsEstimate(0, 100, 0, 640000d * 10000, 10))) .build(); // B table exceeds AUTOMATIC_RESTRICTED limit therefore it is partitioned @@ -474,10 +475,10 @@ public void testReplicatesWhenNotRestricted() .on(p -> p.join( INNER, - p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1")), - p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1")), - ImmutableList.of(new EquiJoinClause(p.symbol("A1"), p.symbol("B1"))), - ImmutableList.of(p.symbol("A1"), p.symbol("B1")), + p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1")), + p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1")), + ImmutableList.of(new EquiJoinClause(p.variable("A1"), p.variable("B1"))), + ImmutableList.of(p.variable("A1"), p.variable("B1")), Optional.empty())) .overrideStats("valuesA", probeSideStatsEstimate) .overrideStats("valuesB", buildSideStatsEstimate) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSingleDistinctAggregationToGroupBy.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSingleDistinctAggregationToGroupBy.java index 7254388d0e02c..91df539787039 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSingleDistinctAggregationToGroupBy.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSingleDistinctAggregationToGroupBy.java @@ -41,11 +41,11 @@ public void testNoDistinct() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), expression("count(input1)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output1")), expression("count(input1)"), ImmutableList.of(BIGINT)) .source( p.values( - p.symbol("input1"), - p.symbol("input2"))))) + p.variable("input1"), + p.variable("input2"))))) .doesNotFire(); } @@ -55,12 +55,12 @@ public void testMultipleDistincts() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), expression("count(DISTINCT input1)"), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), expression("count(DISTINCT input2)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output1")), expression("count(DISTINCT input1)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output2")), expression("count(DISTINCT input2)"), ImmutableList.of(BIGINT)) .source( p.values( - p.symbol("input1"), - p.symbol("input2"))))) + p.variable("input1"), + p.variable("input2"))))) .doesNotFire(); } @@ -70,12 +70,12 @@ public void testMixedDistinctAndNonDistinct() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), expression("count(DISTINCT input1)"), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), expression("count(input2)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output1")), expression("count(DISTINCT input1)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output2")), expression("count(input2)"), ImmutableList.of(BIGINT)) .source( p.values( - p.symbol("input1"), - p.symbol("input2"))))) + p.variable("input1"), + p.variable("input2"))))) .doesNotFire(); } @@ -85,11 +85,11 @@ public void testDistinctWithFilter() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output"), expression("count(DISTINCT input1) filter (where input2 > 0)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output")), expression("count(DISTINCT input1) filter (where input2 > 0)"), ImmutableList.of(BIGINT)) .source( p.values( - p.symbol("input1"), - p.symbol("input2"))))) + p.variable("input1"), + p.variable("input2"))))) .doesNotFire(); } @@ -99,9 +99,9 @@ public void testSingleAggregation() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output"), expression("count(DISTINCT input)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output")), expression("count(DISTINCT input)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("input"))))) + p.values(p.variable("input"))))) .matches( aggregation( globalAggregation(), @@ -126,10 +126,10 @@ public void testMultipleAggregations() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), expression("count(DISTINCT input)"), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2"), expression("sum(DISTINCT input)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output1")), expression("count(DISTINCT input)"), ImmutableList.of(BIGINT)) + .addAggregation(p.variable(p.symbol("output2")), expression("sum(DISTINCT input)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("input"))))) + p.values(p.variable("input"))))) .matches( aggregation( globalAggregation(), @@ -155,10 +155,10 @@ public void testMultipleInputs() tester().assertThat(new SingleDistinctAggregationToGroupBy()) .on(p -> p.aggregation(builder -> builder .globalGrouping() - .addAggregation(p.symbol("output1"), expression("corr(DISTINCT x, y)"), ImmutableList.of(REAL, REAL)) - .addAggregation(p.symbol("output2"), expression("corr(DISTINCT y, x)"), ImmutableList.of(REAL, REAL)) + .addAggregation(p.variable(p.symbol("output1")), expression("corr(DISTINCT x, y)"), ImmutableList.of(REAL, REAL)) + .addAggregation(p.variable(p.symbol("output2")), expression("corr(DISTINCT y, x)"), ImmutableList.of(REAL, REAL)) .source( - p.values(p.symbol("x"), p.symbol("y"))))) + p.values(p.variable("x"), p.variable("y"))))) .matches( aggregation( globalAggregation(), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java index f163ea527028b..de5fceb8ffbfe 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java @@ -66,7 +66,7 @@ public TestSwapAdjacentWindowsBySpecifications() public void doesNotFireOnPlanWithoutWindowFunctions() { tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) - .on(p -> p.values(p.symbol("a"))) + .on(p -> p.values(p.variable("a"))) .doesNotFire(); } @@ -75,11 +75,11 @@ public void doesNotFireOnPlanWithSingleWindowNode() { tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) .on(p -> p.window(new WindowNode.Specification( - ImmutableList.of(p.symbol("a")), + ImmutableList.of(p.variable(p.symbol("a"))), Optional.empty()), - ImmutableMap.of(p.symbol("avg_1"), + ImmutableMap.of(p.variable(p.symbol("avg_1")), new WindowNode.Function(call("avg", functionHandle, DOUBLE, ImmutableList.of()), frame)), - p.values(p.symbol("a")))) + p.values(p.variable("a")))) .doesNotFire(); } @@ -98,14 +98,14 @@ public void subsetComesFirst() tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) .on(p -> p.window(new WindowNode.Specification( - ImmutableList.of(p.symbol("a")), + ImmutableList.of(p.variable(p.symbol("a"))), Optional.empty()), - ImmutableMap.of(p.symbol("avg_1", DOUBLE), newWindowNodeFunction(ImmutableList.of(new Symbol("a")))), + ImmutableMap.of(p.variable(p.symbol("avg_1", DOUBLE)), newWindowNodeFunction(ImmutableList.of(new Symbol("a")))), p.window(new WindowNode.Specification( - ImmutableList.of(p.symbol("a"), p.symbol("b")), + ImmutableList.of(p.variable(p.symbol("a")), p.variable(p.symbol("b"))), Optional.empty()), - ImmutableMap.of(p.symbol("avg_2", DOUBLE), newWindowNodeFunction(ImmutableList.of(new Symbol("b")))), - p.values(p.symbol("a"), p.symbol("b"))))) + ImmutableMap.of(p.variable(p.symbol("avg_2", DOUBLE)), newWindowNodeFunction(ImmutableList.of(new Symbol("b")))), + p.values(p.variable("a"), p.variable("b"))))) .matches( window(windowMatcherBuilder -> windowMatcherBuilder .specification(specificationAB) @@ -124,14 +124,14 @@ public void dependentWindowsAreNotReordered() tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) .on(p -> p.window(new WindowNode.Specification( - ImmutableList.of(p.symbol("a")), + ImmutableList.of(p.variable(p.symbol("a"))), Optional.empty()), - ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction(ImmutableList.of(new Symbol("avg_2")))), + ImmutableMap.of(p.variable(p.symbol("avg_1")), newWindowNodeFunction(ImmutableList.of(new Symbol("avg_2")))), p.window(new WindowNode.Specification( - ImmutableList.of(p.symbol("a"), p.symbol("b")), + ImmutableList.of(p.variable(p.symbol("a")), p.variable(p.symbol("b"))), Optional.empty()), - ImmutableMap.of(p.symbol("avg_2"), newWindowNodeFunction(ImmutableList.of(new Symbol("a")))), - p.values(p.symbol("a"), p.symbol("b"))))) + ImmutableMap.of(p.variable(p.symbol("avg_2")), newWindowNodeFunction(ImmutableList.of(new Symbol("a")))), + p.values(p.variable("a"), p.variable("b"))))) .doesNotFire(); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java index afbc3ab6b62ba..009f843e9f9a7 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java @@ -37,7 +37,7 @@ public class TestTransformCorrelatedScalarAggregationToJoin public void doesNotFireOnPlanWithoutApplyNode() { tester().assertThat(new TransformCorrelatedScalarAggregationToJoin(tester().getMetadata().getFunctionManager())) - .on(p -> p.values(p.symbol("a"))) + .on(p -> p.values(p.variable("a"))) .doesNotFire(); } @@ -46,9 +46,9 @@ public void doesNotFireOnCorrelatedWithoutAggregation() { tester().assertThat(new TransformCorrelatedScalarAggregationToJoin(tester().getMetadata().getFunctionManager())) .on(p -> p.lateral( - ImmutableList.of(p.symbol("corr")), - p.values(p.symbol("corr")), - p.values(p.symbol("a")))) + ImmutableList.of(p.variable(p.symbol("corr"))), + p.values(p.variable("corr")), + p.values(p.variable("a")))) .doesNotFire(); } @@ -58,8 +58,8 @@ public void doesNotFireOnUncorrelated() tester().assertThat(new TransformCorrelatedScalarAggregationToJoin(tester().getMetadata().getFunctionManager())) .on(p -> p.lateral( ImmutableList.of(), - p.values(p.symbol("a")), - p.values(p.symbol("b")))) + p.values(p.variable("a")), + p.values(p.variable("b")))) .doesNotFire(); } @@ -68,12 +68,12 @@ public void doesNotFireOnCorrelatedWithNonScalarAggregation() { tester().assertThat(new TransformCorrelatedScalarAggregationToJoin(tester().getMetadata().getFunctionManager())) .on(p -> p.lateral( - ImmutableList.of(p.symbol("corr")), - p.values(p.symbol("corr")), + ImmutableList.of(p.variable(p.symbol("corr"))), + p.values(p.variable("corr")), p.aggregation(ab -> ab - .source(p.values(p.symbol("a"), p.symbol("b"))) - .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) - .singleGroupingSet(p.symbol("b"))))) + .source(p.values(p.variable("a"), p.variable("b"))) + .addAggregation(p.variable(p.symbol("sum")), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) + .singleGroupingSet(p.variable("b"))))) .doesNotFire(); } @@ -82,11 +82,11 @@ public void rewritesOnSubqueryWithoutProjection() { tester().assertThat(new TransformCorrelatedScalarAggregationToJoin(tester().getMetadata().getFunctionManager())) .on(p -> p.lateral( - ImmutableList.of(p.symbol("corr")), - p.values(p.symbol("corr")), + ImmutableList.of(p.variable(p.symbol("corr"))), + p.values(p.variable("corr")), p.aggregation(ab -> ab - .source(p.values(p.symbol("a"), p.symbol("b"))) - .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) + .source(p.values(p.variable("a"), p.variable("b"))) + .addAggregation(p.variable(p.symbol("sum")), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) .globalGrouping()))) .matches( project(ImmutableMap.of("sum_1", expression("sum_1"), "corr", expression("corr")), @@ -104,12 +104,12 @@ public void rewritesOnSubqueryWithProjection() { tester().assertThat(new TransformCorrelatedScalarAggregationToJoin(tester().getMetadata().getFunctionManager())) .on(p -> p.lateral( - ImmutableList.of(p.symbol("corr")), - p.values(p.symbol("corr")), - p.project(Assignments.of(p.symbol("expr"), p.expression("sum + 1")), + ImmutableList.of(p.variable(p.symbol("corr"))), + p.values(p.variable("corr")), + p.project(Assignments.of(p.variable("expr"), p.expression("sum + 1")), p.aggregation(ab -> ab - .source(p.values(p.symbol("a"), p.symbol("b"))) - .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) + .source(p.values(p.variable("a"), p.variable("b"))) + .addAggregation(p.variable(p.symbol("sum")), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) .globalGrouping())))) .matches( project(ImmutableMap.of("corr", expression("corr"), "expr", expression("(\"sum_1\" + 1)")), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java index cd70861189a7c..eac94f7ec1ea6 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java @@ -17,7 +17,6 @@ import com.facebook.presto.spi.StandardErrorCode; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.type.StandardTypes; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.Assignments; @@ -61,7 +60,7 @@ public class TestTransformCorrelatedScalarSubquery public void doesNotFireOnPlanWithoutLateralNode() { tester().assertThat(rule) - .on(p -> p.values(p.symbol("a"))) + .on(p -> p.values(p.variable("a"))) .doesNotFire(); } @@ -70,9 +69,9 @@ public void doesNotFireOnCorrelatedNonScalar() { tester().assertThat(rule) .on(p -> p.lateral( - ImmutableList.of(p.symbol("corr")), - p.values(p.symbol("corr")), - p.values(p.symbol("a")))) + ImmutableList.of(p.variable("corr")), + p.values(p.variable("corr")), + p.values(p.variable("a")))) .doesNotFire(); } @@ -81,9 +80,9 @@ public void doesNotFireOnUncorrelated() { tester().assertThat(rule) .on(p -> p.lateral( - ImmutableList.of(), - p.values(p.symbol("a")), - p.values(ImmutableList.of(p.symbol("b")), ImmutableList.of(constantExpressions(BIGINT, 1))))) + ImmutableList.of(), + p.values(p.variable("a")), + p.values(ImmutableList.of(p.variable("b")), ImmutableList.of(constantExpressions(BIGINT, 1))))) .doesNotFire(); } @@ -92,12 +91,12 @@ public void rewritesOnSubqueryWithoutProjection() { tester().assertThat(rule) .on(p -> p.lateral( - ImmutableList.of(p.symbol("corr")), - p.values(p.symbol("corr")), + ImmutableList.of(p.variable("corr")), + p.values(p.variable("corr")), p.enforceSingleRow( p.filter( p.expression("1 = a"), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers - p.values(ImmutableList.of(p.symbol("a")), TWO_ROWS))))) + p.values(ImmutableList.of(p.variable("a")), TWO_ROWS))))) .matches( project( filter( @@ -120,14 +119,14 @@ public void rewritesOnSubqueryWithProjection() { tester().assertThat(rule) .on(p -> p.lateral( - ImmutableList.of(p.symbol("corr")), - p.values(p.symbol("corr")), + ImmutableList.of(p.variable("corr")), + p.values(p.variable("corr")), p.enforceSingleRow( p.project( - Assignments.of(p.symbol("a2"), p.expression("a * 2")), + Assignments.of(p.variable("a2"), p.expression("a * 2")), p.filter( p.expression("1 = a"), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers - p.values(ImmutableList.of(p.symbol("a")), TWO_ROWS)))))) + p.values(ImmutableList.of(p.variable("a")), TWO_ROWS)))))) .matches( project( filter( @@ -150,16 +149,16 @@ public void rewritesOnSubqueryWithProjectionOnTopEnforceSingleNode() { tester().assertThat(rule) .on(p -> p.lateral( - ImmutableList.of(p.symbol("corr")), - p.values(p.symbol("corr")), + ImmutableList.of(p.variable("corr")), + p.values(p.variable("corr")), p.project( - Assignments.of(p.symbol("a3"), p.expression("a2 + 1")), + Assignments.of(p.variable("a3"), p.expression("a2 + 1")), p.enforceSingleRow( p.project( - Assignments.of(p.symbol("a2"), p.expression("a * 2")), + Assignments.of(p.variable("a2"), p.expression("a * 2")), p.filter( p.expression("1 = a"), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers - p.values(ImmutableList.of(p.symbol("a")), TWO_ROWS))))))) + p.values(ImmutableList.of(p.variable("a")), TWO_ROWS))))))) .matches( project( filter( @@ -186,12 +185,12 @@ public void rewritesScalarSubquery() { tester().assertThat(rule) .on(p -> p.lateral( - ImmutableList.of(p.symbol("corr")), - p.values(p.symbol("corr")), + ImmutableList.of(p.variable("corr")), + p.values(p.variable("corr")), p.enforceSingleRow( p.filter( p.expression("1 = a"), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers - p.values(ImmutableList.of(p.symbol("a")), ONE_ROW))))) + p.values(ImmutableList.of(p.variable("a")), ONE_ROW))))) .matches( lateral( ImmutableList.of("corr"), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java index 79b1cd93b111d..3a94fd9fd2f9e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java @@ -40,7 +40,7 @@ public class TestTransformCorrelatedSingleRowSubqueryToProject public void testDoesNotFire() { tester().assertThat(new TransformCorrelatedSingleRowSubqueryToProject()) - .on(p -> p.values(p.symbol("a"))) + .on(p -> p.values(p.variable("a"))) .doesNotFire(); } @@ -50,17 +50,18 @@ public void testRewrite() tester().assertThat(new TransformCorrelatedSingleRowSubqueryToProject()) .on(p -> p.lateral( - ImmutableList.of(p.symbol("l_nationkey")), + ImmutableList.of(p.variable(p.symbol("l_nationkey"))), p.tableScan(new TableHandle( new ConnectorId("local"), new TpchTableHandle("nation", TINY_SCALE_FACTOR), TestingTransactionHandle.create(), Optional.empty()), ImmutableList.of(p.symbol("l_nationkey")), - ImmutableMap.of(p.symbol("l_nationkey"), new TpchColumnHandle("nationkey", + ImmutableList.of(p.variable(p.symbol("l_nationkey"))), + ImmutableMap.of(p.variable(p.symbol("l_nationkey")), new TpchColumnHandle("nationkey", BIGINT))), p.project( - Assignments.of(p.symbol("l_expr2"), expression("l_nationkey + 1")), + Assignments.of(p.variable("l_expr2"), expression("l_nationkey + 1")), p.values( ImmutableList.of(), ImmutableList.of(ImmutableList.of()))))) @@ -77,9 +78,9 @@ public void testDoesNotFireWithEmptyValuesNode() tester().assertThat(new TransformCorrelatedSingleRowSubqueryToProject()) .on(p -> p.lateral( - ImmutableList.of(p.symbol("a")), - p.values(p.symbol("a")), - p.values(p.symbol("a")))) + ImmutableList.of(p.variable(p.symbol("a"))), + p.values(p.variable("a")), + p.values(p.variable("a")))) .doesNotFire(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToLateralJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToLateralJoin.java index 5326704f45481..93a9471c176dd 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToLateralJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToLateralJoin.java @@ -38,15 +38,15 @@ public class TestTransformExistsApplyToLateralJoin public void testDoesNotFire() { tester().assertThat(new TransformExistsApplyToLateralNode(tester().getMetadata().getFunctionManager())) - .on(p -> p.values(p.symbol("a"))) + .on(p -> p.values(p.variable("a"))) .doesNotFire(); tester().assertThat(new TransformExistsApplyToLateralNode(tester().getMetadata().getFunctionManager())) .on(p -> p.lateral( - ImmutableList.of(p.symbol("a")), - p.values(p.symbol("a")), - p.values(p.symbol("a")))) + ImmutableList.of(p.variable(p.symbol("a"))), + p.values(p.variable("a")), + p.values(p.variable("a")))) .doesNotFire(); } @@ -56,7 +56,7 @@ public void testRewrite() tester().assertThat(new TransformExistsApplyToLateralNode(tester().getMetadata().getFunctionManager())) .on(p -> p.apply( - Assignments.of(p.symbol("b", BOOLEAN), expression("EXISTS(SELECT TRUE)")), + Assignments.of(p.variable("b", BOOLEAN), expression("EXISTS(SELECT TRUE)")), ImmutableList.of(), p.values(), p.values())) @@ -75,13 +75,13 @@ public void testRewritesToLimit() tester().assertThat(new TransformExistsApplyToLateralNode(tester().getMetadata().getFunctionManager())) .on(p -> p.apply( - Assignments.of(p.symbol("b", BOOLEAN), expression("EXISTS(SELECT TRUE)")), - ImmutableList.of(p.symbol("corr")), - p.values(p.symbol("corr")), + Assignments.of(p.variable("b", BOOLEAN), expression("EXISTS(SELECT TRUE)")), + ImmutableList.of(p.variable(p.symbol("corr"))), + p.values(p.variable("corr")), p.project(Assignments.of(), p.filter( expression("corr = column"), - p.values(p.symbol("column")))))) + p.values(p.variable("column")))))) .matches( project(ImmutableMap.of("b", PlanMatchPattern.expression("COALESCE(subquerytrue, false)")), lateral( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java index 47255b4fc74b4..f4c1a3663367d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java @@ -46,7 +46,7 @@ public void testDoesNotFireOnNonInPredicateSubquery() { tester().assertThat(new TransformUncorrelatedInPredicateSubqueryToSemiJoin()) .on(p -> p.apply( - Assignments.of(p.symbol("x"), new ExistsPredicate(new LongLiteral("1"))), + Assignments.of(p.variable("x"), new ExistsPredicate(new LongLiteral("1"))), emptyList(), p.values(), p.values())) @@ -59,13 +59,13 @@ public void testFiresForInPredicate() tester().assertThat(new TransformUncorrelatedInPredicateSubqueryToSemiJoin()) .on(p -> p.apply( Assignments.of( - p.symbol("x"), + p.variable("x"), new InPredicate( new SymbolReference("y"), new SymbolReference("z"))), emptyList(), - p.values(p.symbol("y")), - p.values(p.symbol("z")))) + p.values(p.variable("y")), + p.values(p.variable("z")))) .matches(node(SemiJoinNode.class, values("y"), values("z"))); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedLateralToJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedLateralToJoin.java index 9ea4319eb045b..7671ca93d5070 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedLateralToJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedLateralToJoin.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.JoinNode; import com.google.common.collect.ImmutableList; @@ -38,10 +37,9 @@ public void test() @Test public void testDoesNotFire() { - Symbol symbol = new Symbol("x"); tester() .assertThat(new TransformUncorrelatedLateralToJoin()) - .on(p -> p.lateral(ImmutableList.of(symbol), p.values(symbol), p.values())) + .on(p -> p.lateral(ImmutableList.of(p.variable(p.symbol("x"))), p.values(p.variable("x")), p.values())) .doesNotFire(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 754b0ca05960d..9abf73d8ae434 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.analyzer.TypeSignatureProvider; @@ -34,7 +35,6 @@ import com.facebook.presto.sql.planner.OrderingScheme; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; -import com.facebook.presto.sql.planner.PlannerUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TestingConnectorIndexHandle; import com.facebook.presto.sql.planner.TestingConnectorTransactionHandle; @@ -57,10 +57,8 @@ import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.OutputNode; -import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; -import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; @@ -75,7 +73,7 @@ import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; -import com.facebook.presto.sql.tree.NullLiteral; +import com.facebook.presto.sql.tree.OrderBy; import com.facebook.presto.testing.TestingMetadata.TestingTableHandle; import com.facebook.presto.testing.TestingTransactionHandle; import com.google.common.base.Functions; @@ -97,6 +95,7 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.sql.planner.PlannerUtils.toOrderingScheme; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.facebook.presto.sql.relational.Expressions.constant; @@ -123,13 +122,13 @@ public PlanBuilder(PlanNodeIdAllocator idAllocator, Metadata metadata) this.metadata = metadata; } - public OutputNode output(List columnNames, List outputs, PlanNode source) + public OutputNode output(List columnNames, List outputs, List variables, PlanNode source) { return new OutputNode( idAllocator.getNextId(), source, columnNames, - outputs); + variables); } public OutputNode output(Consumer outputBuilderConsumer) @@ -143,7 +142,7 @@ public class OutputBuilder { private PlanNode source; private List columnNames = new ArrayList<>(); - private List outputs = new ArrayList<>(); + private List outputVariables = new ArrayList<>(); public OutputBuilder source(PlanNode source) { @@ -151,55 +150,56 @@ public OutputBuilder source(PlanNode source) return this; } - public OutputBuilder column(Symbol symbol) + public OutputBuilder column(VariableReferenceExpression variable, String columnName) { - return column(symbol, symbol.getName()); - } - - public OutputBuilder column(Symbol symbol, String columnName) - { - outputs.add(symbol); + outputVariables.add(variable); columnNames.add(columnName); return this; } protected OutputNode build() { - return new OutputNode(idAllocator.getNextId(), source, columnNames, outputs); + return new OutputNode(idAllocator.getNextId(), source, columnNames, outputVariables); } } - public ValuesNode values(Symbol... columns) + public ValuesNode values() { - return values(idAllocator.getNextId(), columns); + return values(idAllocator.getNextId(), ImmutableList.of(), ImmutableList.of()); } - public ValuesNode values(PlanNodeId id, Symbol... columns) + public ValuesNode values(VariableReferenceExpression... columns) + { + return values(idAllocator.getNextId(), 0, columns); + } + + public ValuesNode values(PlanNodeId id, VariableReferenceExpression... columns) { return values(id, 0, columns); } - public ValuesNode values(int rows, Symbol... columns) + public ValuesNode values(int rows, VariableReferenceExpression... columns) { return values(idAllocator.getNextId(), rows, columns); } - public ValuesNode values(PlanNodeId id, int rows, Symbol... columns) + public ValuesNode values(PlanNodeId id, int rows, VariableReferenceExpression... columns) { + List variables = ImmutableList.copyOf(columns); return values( id, - ImmutableList.copyOf(columns), + variables, nElements(rows, row -> nElements(columns.length, cell -> constantNull(UNKNOWN)))); } - public ValuesNode values(List columns, List> rows) + public ValuesNode values(List variables, List> rows) { - return values(idAllocator.getNextId(), columns, rows); + return values(idAllocator.getNextId(), variables, rows); } - public ValuesNode values(PlanNodeId id, List columns, List> rows) + public ValuesNode values(PlanNodeId id, List variables, List> rows) { - return new ValuesNode(id, columns, rows); + return new ValuesNode(id, variables, rows); } public EnforceSingleRowNode enforceSingleRow(PlanNode source) @@ -212,7 +212,7 @@ public LimitNode limit(long limit, PlanNode source) return new LimitNode(idAllocator.getNextId(), source, limit, false); } - public TopNNode topN(long count, List orderBy, PlanNode source) + public TopNNode topN(long count, List orderBy, PlanNode source) { return new TopNNode( idAllocator.getNextId(), @@ -234,14 +234,14 @@ public ProjectNode project(Assignments assignments, PlanNode source) return new ProjectNode(idAllocator.getNextId(), source, assignments); } - public MarkDistinctNode markDistinct(Symbol markerSymbol, List distinctSymbols, PlanNode source) + public MarkDistinctNode markDistinct(VariableReferenceExpression markerVariable, List distinctVariables, PlanNode source) { - return new MarkDistinctNode(idAllocator.getNextId(), source, markerSymbol, distinctSymbols, Optional.empty()); + return new MarkDistinctNode(idAllocator.getNextId(), source, markerVariable, distinctVariables, Optional.empty()); } - public MarkDistinctNode markDistinct(Symbol markerSymbol, List distinctSymbols, Symbol hashSymbol, PlanNode source) + public MarkDistinctNode markDistinct(VariableReferenceExpression markerVariable, List distinctVariables, VariableReferenceExpression hashVariable, PlanNode source) { - return new MarkDistinctNode(idAllocator.getNextId(), source, markerSymbol, distinctSymbols, Optional.of(hashSymbol)); + return new MarkDistinctNode(idAllocator.getNextId(), source, markerVariable, distinctVariables, Optional.of(hashVariable)); } public FilterNode filter(Expression predicate, PlanNode source) @@ -256,39 +256,44 @@ public FilterNode filter(RowExpression predicate, PlanNode source) public AggregationNode aggregation(Consumer aggregationBuilderConsumer) { - AggregationBuilder aggregationBuilder = new AggregationBuilder(); + AggregationBuilder aggregationBuilder = new AggregationBuilder(getTypes()); aggregationBuilderConsumer.accept(aggregationBuilder); return aggregationBuilder.build(); } public class AggregationBuilder { + private final TypeProvider types; private PlanNode source; - private Map assignments = new HashMap<>(); + private Map assignments = new HashMap<>(); private AggregationNode.GroupingSetDescriptor groupingSets; - private List preGroupedSymbols = new ArrayList<>(); + private List preGroupedVariables = new ArrayList<>(); private Step step = Step.SINGLE; - private Optional hashSymbol = Optional.empty(); - private Optional groupIdSymbol = Optional.empty(); + private Optional hashVariable = Optional.empty(); + private Optional groupIdVariable = Optional.empty(); private Session session = testSessionBuilder().build(); + public AggregationBuilder(TypeProvider types) + { + this.types = types; + } public AggregationBuilder source(PlanNode source) { this.source = source; return this; } - public AggregationBuilder addAggregation(Symbol output, Expression expression, List inputTypes) + public AggregationBuilder addAggregation(VariableReferenceExpression output, Expression expression, List inputTypes) { return addAggregation(output, expression, inputTypes, Optional.empty()); } - public AggregationBuilder addAggregation(Symbol output, Expression expression, List inputTypes, Symbol mask) + public AggregationBuilder addAggregation(VariableReferenceExpression output, Expression expression, List inputTypes, VariableReferenceExpression mask) { return addAggregation(output, expression, inputTypes, Optional.of(mask)); } - private AggregationBuilder addAggregation(Symbol output, Expression expression, List inputTypes, Optional mask) + private AggregationBuilder addAggregation(VariableReferenceExpression output, Expression expression, List inputTypes, Optional mask) { checkArgument(expression instanceof FunctionCall); FunctionCall call = (FunctionCall) expression; @@ -297,12 +302,12 @@ private AggregationBuilder addAggregation(Symbol output, Expression expression, functionHandle, call.getArguments(), call.getFilter(), - call.getOrderBy().map(PlannerUtils::toOrderingScheme), + call.getOrderBy().map(OrderBy::getSortItems).map(sortItems -> toOrderingScheme(sortItems, types)), call.isDistinct(), mask)); } - public AggregationBuilder addAggregation(Symbol output, Aggregation aggregation) + public AggregationBuilder addAggregation(VariableReferenceExpression output, Aggregation aggregation) { assignments.put(output, aggregation); return this; @@ -314,9 +319,9 @@ public AggregationBuilder globalGrouping() return this; } - public AggregationBuilder singleGroupingSet(Symbol... symbols) + public AggregationBuilder singleGroupingSet(VariableReferenceExpression... variables) { - groupingSets(AggregationNode.singleGroupingSet(ImmutableList.copyOf(symbols))); + groupingSets(AggregationNode.singleGroupingSet(ImmutableList.copyOf(variables))); return this; } @@ -327,10 +332,10 @@ public AggregationBuilder groupingSets(AggregationNode.GroupingSetDescriptor gro return this; } - public AggregationBuilder preGroupedSymbols(Symbol... symbols) + public AggregationBuilder preGroupedVariables(VariableReferenceExpression... variables) { - checkState(this.preGroupedSymbols.isEmpty(), "preGroupedSymbols already defined"); - this.preGroupedSymbols = ImmutableList.copyOf(symbols); + checkState(this.preGroupedVariables.isEmpty(), "preGroupedVariables already defined"); + this.preGroupedVariables = ImmutableList.copyOf(variables); return this; } @@ -340,15 +345,15 @@ public AggregationBuilder step(Step step) return this; } - public AggregationBuilder hashSymbol(Symbol hashSymbol) + public AggregationBuilder hashVariable(VariableReferenceExpression hashVariable) { - this.hashSymbol = Optional.of(hashSymbol); + this.hashVariable = Optional.of(hashVariable); return this; } - public AggregationBuilder groupIdSymbol(Symbol groupIdSymbol) + public AggregationBuilder groupIdVariable(VariableReferenceExpression groupIdVariable) { - this.groupIdSymbol = Optional.of(groupIdSymbol); + this.groupIdVariable = Optional.of(groupIdVariable); return this; } @@ -360,65 +365,63 @@ protected AggregationNode build() source, assignments, groupingSets, - preGroupedSymbols, + preGroupedVariables, step, - hashSymbol, - groupIdSymbol); + hashVariable, + groupIdVariable); } } - public ApplyNode apply(Assignments subqueryAssignments, List correlation, PlanNode input, PlanNode subquery) + public ApplyNode apply(Assignments subqueryAssignments, List correlation, PlanNode input, PlanNode subquery) { return new ApplyNode(idAllocator.getNextId(), input, subquery, subqueryAssignments, correlation, ""); } - public AssignUniqueId assignUniqueId(Symbol unique, PlanNode source) + public AssignUniqueId assignUniqueId(VariableReferenceExpression variable, PlanNode source) { - return new AssignUniqueId(idAllocator.getNextId(), source, unique); + return new AssignUniqueId(idAllocator.getNextId(), source, variable); } - public LateralJoinNode lateral(List correlation, PlanNode input, PlanNode subquery) + public LateralJoinNode lateral(List correlation, PlanNode input, PlanNode subquery) { - NullLiteral originSubquery = new NullLiteral(); // does not matter for tests return new LateralJoinNode(idAllocator.getNextId(), input, subquery, correlation, LateralJoinNode.Type.INNER, ""); } - public TableScanNode tableScan(List symbols, Map assignments) + public TableScanNode tableScan(List symbols, List variables, Map assignments) { TableHandle tableHandle = new TableHandle( new ConnectorId("testConnector"), new TestingTableHandle(), TestingTransactionHandle.create(), Optional.empty()); - return tableScan(tableHandle, symbols, assignments, TupleDomain.all(), TupleDomain.all()); + return tableScan(tableHandle, symbols, variables, assignments, TupleDomain.all(), TupleDomain.all()); } - public TableScanNode tableScan( - TableHandle tableHandle, - List symbols, - Map assignments) + public TableScanNode tableScan(TableHandle tableHandle, List symbols, List variables, Map assignments) { - return tableScan(tableHandle, symbols, assignments, TupleDomain.all(), TupleDomain.all()); + return tableScan(tableHandle, symbols, variables, assignments, TupleDomain.all(), TupleDomain.all()); } public TableScanNode tableScan( TableHandle tableHandle, List symbols, - Map assignments, + List variables, + Map assignments, TupleDomain currentConstraint, TupleDomain enforcedConstraint) { return new TableScanNode( idAllocator.getNextId(), tableHandle, - symbols, + variables, assignments, currentConstraint, enforcedConstraint); } - public TableFinishNode tableDelete(SchemaTableName schemaTableName, PlanNode deleteSource, Symbol deleteRowId) + public TableFinishNode tableDelete(SchemaTableName schemaTableName, PlanNode deleteSource, VariableReferenceExpression deleteRowId) { + Symbol deleteRowIdSymbol = new Symbol(deleteRowId.getName()); TableWriterNode.DeleteHandle deleteHandle = new TableWriterNode.DeleteHandle( new TableHandle( new ConnectorId("testConnector"), @@ -447,58 +450,58 @@ public ExchangeNode gatheringExchange(ExchangeNode.Scope scope, PlanNode child) { return exchange(builder -> builder.type(ExchangeNode.Type.GATHER) .scope(scope) - .singleDistributionPartitioningScheme(child.getOutputSymbols()) + .singleDistributionPartitioningScheme(child.getOutputVariables()) .addSource(child) - .addInputsSet(child.getOutputSymbols())); + .addInputsSet(child.getOutputVariables())); } public SemiJoinNode semiJoin( - Symbol sourceJoinSymbol, - Symbol filteringSourceJoinSymbol, - Symbol semiJoinOutput, - Optional sourceHashSymbol, - Optional filteringSourceHashSymbol, + VariableReferenceExpression sourceJoinVariable, + VariableReferenceExpression filteringSourceJoinVariable, + VariableReferenceExpression semiJoinOutput, + Optional sourceHashVariable, + Optional filteringSourceHashVariable, PlanNode source, PlanNode filteringSource) { return semiJoin( source, filteringSource, - sourceJoinSymbol, - filteringSourceJoinSymbol, + sourceJoinVariable, + filteringSourceJoinVariable, semiJoinOutput, - sourceHashSymbol, - filteringSourceHashSymbol, + sourceHashVariable, + filteringSourceHashVariable, Optional.empty()); } public SemiJoinNode semiJoin( PlanNode source, PlanNode filteringSource, - Symbol sourceJoinSymbol, - Symbol filteringSourceJoinSymbol, - Symbol semiJoinOutput, - Optional sourceHashSymbol, - Optional filteringSourceHashSymbol, + VariableReferenceExpression sourceJoinVariable, + VariableReferenceExpression filteringSourceJoinVariable, + VariableReferenceExpression semiJoinOutput, + Optional sourceHashVariable, + Optional filteringSourceHashVariable, Optional distributionType) { return new SemiJoinNode( idAllocator.getNextId(), source, filteringSource, - sourceJoinSymbol, - filteringSourceJoinSymbol, + sourceJoinVariable, + filteringSourceJoinVariable, semiJoinOutput, - sourceHashSymbol, - filteringSourceHashSymbol, + sourceHashVariable, + filteringSourceHashVariable, distributionType); } public IndexSourceNode indexSource( TableHandle tableHandle, - Set lookupSymbols, - List outputSymbols, - Map assignments, + Set lookupVariables, + List outputVariables, + Map assignments, TupleDomain effectiveTupleDomain) { return new IndexSourceNode( @@ -508,8 +511,8 @@ public IndexSourceNode indexSource( TestingConnectorTransactionHandle.INSTANCE, TestingConnectorIndexHandle.INSTANCE), tableHandle, - lookupSymbols, - outputSymbols, + lookupVariables, + outputVariables, assignments, effectiveTupleDomain); } @@ -528,7 +531,7 @@ public class ExchangeBuilder private PartitioningScheme partitioningScheme; private OrderingScheme orderingScheme; private List sources = new ArrayList<>(); - private List> inputs = new ArrayList<>(); + private List> inputs = new ArrayList<>(); public ExchangeBuilder type(ExchangeNode.Type type) { @@ -542,31 +545,31 @@ public ExchangeBuilder scope(ExchangeNode.Scope scope) return this; } - public ExchangeBuilder singleDistributionPartitioningScheme(Symbol... outputSymbols) + public ExchangeBuilder singleDistributionPartitioningScheme(VariableReferenceExpression... outputVariables) { - return singleDistributionPartitioningScheme(Arrays.asList(outputSymbols)); + return singleDistributionPartitioningScheme(Arrays.asList(outputVariables)); } - public ExchangeBuilder singleDistributionPartitioningScheme(List outputSymbols) + public ExchangeBuilder singleDistributionPartitioningScheme(List outputVariables) { - return partitioningScheme(new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), outputSymbols)); + return partitioningScheme(new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), outputVariables)); } - public ExchangeBuilder fixedHashDistributionParitioningScheme(List outputSymbols, List partitioningSymbols) + public ExchangeBuilder fixedHashDistributionParitioningScheme(List outputVariables, List partitioningVariables) { return partitioningScheme(new PartitioningScheme(Partitioning.create( FIXED_HASH_DISTRIBUTION, - ImmutableList.copyOf(partitioningSymbols)), - ImmutableList.copyOf(outputSymbols))); + ImmutableList.copyOf(partitioningVariables)), + ImmutableList.copyOf(outputVariables))); } - public ExchangeBuilder fixedHashDistributionParitioningScheme(List outputSymbols, List partitioningSymbols, Symbol hashSymbol) + public ExchangeBuilder fixedHashDistributionParitioningScheme(List outputVariables, List partitioningVariables, VariableReferenceExpression hashVariable) { return partitioningScheme(new PartitioningScheme(Partitioning.create( FIXED_HASH_DISTRIBUTION, - ImmutableList.copyOf(partitioningSymbols)), - ImmutableList.copyOf(outputSymbols), - Optional.of(hashSymbol))); + ImmutableList.copyOf(partitioningVariables)), + ImmutableList.copyOf(outputVariables), + Optional.of(hashVariable))); } public ExchangeBuilder partitioningScheme(PartitioningScheme partitioningScheme) @@ -581,12 +584,12 @@ public ExchangeBuilder addSource(PlanNode source) return this; } - public ExchangeBuilder addInputsSet(Symbol... inputs) + public ExchangeBuilder addInputsSet(VariableReferenceExpression... inputs) { return addInputsSet(Arrays.asList(inputs)); } - public ExchangeBuilder addInputsSet(List inputs) + public ExchangeBuilder addInputsSet(List inputs) { this.inputs.add(inputs); return this; @@ -621,18 +624,18 @@ private JoinNode join(JoinNode.Type joinType, PlanNode left, PlanNode right, Opt left, right, ImmutableList.copyOf(criteria), - ImmutableList.builder() - .addAll(left.getOutputSymbols()) - .addAll(right.getOutputSymbols()) + ImmutableList.builder() + .addAll(left.getOutputVariables()) + .addAll(right.getOutputVariables()) .build(), filter, Optional.empty(), Optional.empty()); } - public JoinNode join(JoinNode.Type type, PlanNode left, PlanNode right, List criteria, List outputSymbols, Optional filter) + public JoinNode join(JoinNode.Type type, PlanNode left, PlanNode right, List criteria, List outputVariables, Optional filter) { - return join(type, left, right, criteria, outputSymbols, filter, Optional.empty(), Optional.empty()); + return join(type, left, right, criteria, outputVariables, filter, Optional.empty(), Optional.empty()); } public JoinNode join( @@ -640,12 +643,12 @@ public JoinNode join( PlanNode left, PlanNode right, List criteria, - List outputSymbols, + List outputVariables, Optional filter, - Optional leftHashSymbol, - Optional rightHashSymbol) + Optional leftHashVariable, + Optional rightHashVariable) { - return join(type, left, right, criteria, outputSymbols, filter, leftHashSymbol, rightHashSymbol, Optional.empty()); + return join(type, left, right, criteria, outputVariables, filter, leftHashVariable, rightHashVariable, Optional.empty()); } public JoinNode join( @@ -653,13 +656,13 @@ public JoinNode join( PlanNode left, PlanNode right, List criteria, - List outputSymbols, + List outputVariables, Optional filter, - Optional leftHashSymbol, - Optional rightHashSymbol, + Optional leftHashVariable, + Optional rightHashVariable, Optional distributionType) { - return new JoinNode(idAllocator.getNextId(), type, left, right, criteria, outputSymbols, filter.map(OriginalExpressionUtils::castToRowExpression), leftHashSymbol, rightHashSymbol, distributionType); + return new JoinNode(idAllocator.getNextId(), type, left, right, criteria, outputVariables, filter.map(OriginalExpressionUtils::castToRowExpression), leftHashVariable, rightHashVariable, distributionType); } public PlanNode indexJoin(IndexJoinNode.Type type, TableScanNode probe, TableScanNode index) @@ -674,21 +677,20 @@ public PlanNode indexJoin(IndexJoinNode.Type type, TableScanNode probe, TableSca Optional.empty()); } - public UnionNode union(ListMultimap outputsToInputs, List sources) + public UnionNode union(ListMultimap outputsToInputs, List sources) { - ImmutableList outputs = outputsToInputs.keySet().stream().collect(toImmutableList()); - return new UnionNode(idAllocator.getNextId(), sources, outputsToInputs, outputs); + return new UnionNode(idAllocator.getNextId(), sources, outputsToInputs); } - public TableWriterNode tableWriter(List columns, List columnNames, PlanNode source) + public TableWriterNode tableWriter(List columns, List columnNames, PlanNode source) { return new TableWriterNode( idAllocator.getNextId(), source, new TestingWriterTarget(), - symbol("partialrows", BIGINT), - symbol("fragment", VARBINARY), - symbol("tablecommitcontext", VARBINARY), + variable("partialrows", BIGINT), + variable("fragment", VARBINARY), + variable("tablecommitcontext", VARBINARY), columns, columnNames, Optional.empty(), @@ -696,6 +698,27 @@ public TableWriterNode tableWriter(List columns, List columnName Optional.empty()); } + public VariableReferenceExpression variable(String name) + { + return variable(symbol(name, BIGINT)); + } + + public VariableReferenceExpression variable(Symbol symbol) + { + return new VariableReferenceExpression(symbol.getName(), symbols.get(symbol)); + } + + public VariableReferenceExpression variable(VariableReferenceExpression variable) + { + return variable(variable.getName(), variable.getType()); + } + + public VariableReferenceExpression variable(String name, Type type) + { + Symbol s = symbol(name, type); + return new VariableReferenceExpression(s.getName(), type); + } + public Symbol symbol(String name) { return symbol(name, BIGINT); @@ -717,7 +740,7 @@ public Symbol symbol(String name, Type type) return symbol; } - public WindowNode window(WindowNode.Specification specification, Map functions, PlanNode source) + public WindowNode window(WindowNode.Specification specification, Map functions, PlanNode source) { return new WindowNode( idAllocator.getNextId(), @@ -729,42 +752,37 @@ public WindowNode window(WindowNode.Specification specification, Map functions, Symbol hashSymbol, PlanNode source) + public WindowNode window(WindowNode.Specification specification, Map functions, VariableReferenceExpression hashVariable, PlanNode source) { return new WindowNode( idAllocator.getNextId(), source, specification, ImmutableMap.copyOf(functions), - Optional.of(hashSymbol), + Optional.of(hashVariable), ImmutableSet.of(), 0); } - public RowNumberNode rowNumber(List partitionBy, Optional maxRowCountPerPartition, Symbol rowNumberSymbol, PlanNode source) + public RowNumberNode rowNumber(List partitionBy, Optional maxRowCountPerPartition, Symbol rowNumberSymbol, VariableReferenceExpression rownNumberVariable, PlanNode source) { return new RowNumberNode( idAllocator.getNextId(), source, partitionBy, - rowNumberSymbol, + rownNumberVariable, maxRowCountPerPartition, Optional.empty()); } - public RemoteSourceNode remoteSourceNode(List fragmentIds, List symbols, ExchangeNode.Type exchangeType) - { - return new RemoteSourceNode(idAllocator.getNextId(), fragmentIds, symbols, Optional.empty(), exchangeType); - } - - public UnnestNode unnest(PlanNode source, List replicateSymbols, Map> unnestSymbols, Optional ordinalitySymbol) + public UnnestNode unnest(PlanNode source, List replicateVariables, Map> unnestVariables, Optional ordinalityVariable) { return new UnnestNode( idAllocator.getNextId(), source, - replicateSymbols, - unnestSymbols, - ordinalitySymbol); + replicateVariables, + unnestVariables, + ordinalityVariable); } public static Expression expression(String sql) @@ -793,6 +811,6 @@ public static List constantExpressions(Type type, Object... value public TypeProvider getTypes() { - return TypeProvider.copyOf(symbols); + return TypeProvider.viewOf(symbols); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java index f05de7e0e6bbb..1f110dd656a99 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java @@ -140,14 +140,14 @@ public void matches(PlanMatchPattern pattern) formatPlan(plan, types))); } - if (!ImmutableSet.copyOf(plan.getOutputSymbols()).equals(ImmutableSet.copyOf(actual.getOutputSymbols()))) { + if (!ImmutableSet.copyOf(plan.getOutputVariables()).equals(ImmutableSet.copyOf(actual.getOutputVariables()))) { fail(String.format( "%s: output schema of transformed and original plans are not equivalent\n" + "\texpected: %s\n" + "\tactual: %s", rule.getClass().getName(), - plan.getOutputSymbols(), - actual.getOutputSymbols())); + plan.getOutputVariables(), + actual.getOutputVariables())); } inTransaction(session -> { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java index 729c330d3fb02..50e2ff8f26e98 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java @@ -35,9 +35,9 @@ public void testReportWrongMatch() tester.assertThat(new DummyReplaceNodeRule()) .on(p -> p.project( - Assignments.of(p.symbol("y"), expression("x")), + Assignments.of(p.variable("y"), expression("x")), p.values( - ImmutableList.of(p.symbol("x")), + ImmutableList.of(p.variable(p.symbol("x"))), ImmutableList.of(constantExpressions(BIGINT, 1))))) .matches( values(ImmutableList.of("different"), ImmutableList.of())); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchanges.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchanges.java index ea3094bffe313..2645a28528332 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchanges.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchanges.java @@ -17,7 +17,7 @@ import com.facebook.presto.spi.GroupingProperty; import com.facebook.presto.spi.SortingProperty; import com.facebook.presto.spi.block.SortOrder; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.optimizations.ActualProperties.Global; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -32,6 +32,7 @@ import java.util.Optional; import static com.facebook.presto.spi.block.SortOrder.ASC_NULLS_FIRST; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static com.facebook.presto.sql.planner.optimizations.ActualProperties.Global.arbitraryPartition; import static com.facebook.presto.sql.planner.optimizations.ActualProperties.Global.partitionedOn; @@ -208,7 +209,7 @@ public void testPickLayoutUnpartitionedPreference() public void testPickLayoutPartitionedOnSingle() { Comparator preference = streamingExecutionPreference( - PreferredProperties.partitioned(ImmutableSet.of(symbol("a")))); + PreferredProperties.partitioned(ImmutableSet.of(variable("a")))); List input = ImmutableList.builder() .add(builder() @@ -270,7 +271,7 @@ public void testPickLayoutPartitionedOnSingle() public void testPickLayoutPartitionedOnMultiple() { Comparator preference = streamingExecutionPreference( - PreferredProperties.partitioned(ImmutableSet.of(symbol("a"), symbol("b")))); + PreferredProperties.partitioned(ImmutableSet.of(variable("a"), variable("b")))); List input = ImmutableList.builder() .add(builder() @@ -679,7 +680,7 @@ public void testPickLayoutPartitionedWithGroup() { Comparator preference = streamingExecutionPreference (PreferredProperties.partitionedWithLocal( - ImmutableSet.of(symbol("a")), + ImmutableSet.of(variable("a")), ImmutableList.of(grouped("a")))); List input = ImmutableList.builder() @@ -766,30 +767,30 @@ private static Global streamPartitionedOn(String... columnNames) return Global.streamPartitionedOn(arguments(columnNames)); } - private static ConstantProperty constant(String column) + private static ConstantProperty constant(String column) { - return new ConstantProperty<>(symbol(column)); + return new ConstantProperty<>(variable(column)); } - private static GroupingProperty grouped(String... columns) + private static GroupingProperty grouped(String... columns) { - return new GroupingProperty<>(Lists.transform(Arrays.asList(columns), Symbol::new)); + return new GroupingProperty<>(Lists.transform(Arrays.asList(columns), column -> new VariableReferenceExpression(column, BIGINT))); } - private static SortingProperty sorted(String column, SortOrder order) + private static SortingProperty sorted(String column, SortOrder order) { - return new SortingProperty<>(symbol(column), order); + return new SortingProperty<>(variable(column), order); } - private static Symbol symbol(String name) + private static VariableReferenceExpression variable(String name) { - return new Symbol(name); + return new VariableReferenceExpression(name, BIGINT); } - private static List arguments(String[] columnNames) + private static List arguments(String[] columnNames) { return Arrays.asList(columnNames).stream() - .map(Symbol::new) + .map(column -> new VariableReferenceExpression(column, BIGINT)) .collect(toImmutableList()); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestAssingments.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestAssingments.java index 4f121265984e5..25cf226cc8f32 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestAssingments.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestAssingments.java @@ -13,16 +13,17 @@ */ package com.facebook.presto.sql.planner.plan; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.google.common.collect.ImmutableCollection; import org.testng.annotations.Test; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static org.testng.Assert.assertTrue; public class TestAssingments { - private final Assignments assignments = Assignments.of(new Symbol("test"), TRUE_LITERAL); + private final Assignments assignments = Assignments.of(new VariableReferenceExpression("test", BIGINT), TRUE_LITERAL); @Test public void testOutputsImmutable() diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java index 18e0f8e166a51..48e82991bf7e4 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java @@ -24,6 +24,7 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.Serialization; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.OrderingScheme; import com.facebook.presto.sql.planner.Symbol; @@ -57,6 +58,8 @@ import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.UNBOUNDED_PRECEDING; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE; import static com.facebook.presto.sql.relational.Expressions.call; +import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static io.airlift.configuration.ConfigBinder.configBinder; import static io.airlift.json.JsonBinder.jsonBinder; import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; import static org.testng.Assert.assertEquals; @@ -68,6 +71,9 @@ public class TestWindowNode private Symbol columnA; private Symbol columnB; private Symbol columnC; + private VariableReferenceExpression variableA; + private VariableReferenceExpression variableB; + private VariableReferenceExpression variableC; private final JsonCodec codec; @@ -85,16 +91,20 @@ public void setUp() columnB = symbolAllocator.newSymbol("b", BIGINT); columnC = symbolAllocator.newSymbol("c", BIGINT); + variableA = new VariableReferenceExpression(columnA.getName(), BIGINT); + variableB = new VariableReferenceExpression(columnB.getName(), BIGINT); + variableC = new VariableReferenceExpression(columnC.getName(), BIGINT); + sourceNode = new ValuesNode( newId(), - ImmutableList.of(columnA, columnB, columnC), + ImmutableList.of(variableA, variableB, variableC), ImmutableList.of()); } @Test public void testSerializationRoundtrip() { - Symbol windowSymbol = symbolAllocator.newSymbol("sum", BIGINT); + VariableReferenceExpression windowVariable = symbolAllocator.newVariable("sum", BIGINT); FunctionHandle functionHandle = createTestMetadataManager().getFunctionManager().lookupFunction("sum", fromTypes(BIGINT)); WindowNode.Frame frame = new WindowNode.Frame( RANGE, @@ -107,20 +117,20 @@ public void testSerializationRoundtrip() PlanNodeId id = newId(); WindowNode.Specification specification = new WindowNode.Specification( - ImmutableList.of(columnA), + ImmutableList.of(variableA), Optional.of(new OrderingScheme( - ImmutableList.of(columnB), - ImmutableMap.of(columnB, SortOrder.ASC_NULLS_FIRST)))); + ImmutableList.of(variableB), + ImmutableMap.of(variableB, SortOrder.ASC_NULLS_FIRST)))); CallExpression call = call("sum", functionHandle, BIGINT, new VariableReferenceExpression(columnC.getName(), BIGINT)); - Map functions = ImmutableMap.of(windowSymbol, new WindowNode.Function(call, frame)); - Optional hashSymbol = Optional.of(columnB); - Set prePartitionedInputs = ImmutableSet.of(columnA); + Map functions = ImmutableMap.of(windowVariable, new WindowNode.Function(call, frame)); + Optional hashVariable = Optional.of(variableB); + Set prePartitionedInputs = ImmutableSet.of(variableA); WindowNode windowNode = new WindowNode( id, sourceNode, specification, functions, - hashSymbol, + hashVariable, prePartitionedInputs, 0); @@ -132,7 +142,7 @@ public void testSerializationRoundtrip() assertEquals(actualNode.getSpecification(), windowNode.getSpecification()); assertEquals(actualNode.getWindowFunctions(), windowNode.getWindowFunctions()); assertEquals(actualNode.getFrames(), windowNode.getFrames()); - assertEquals(actualNode.getHashSymbol(), windowNode.getHashSymbol()); + assertEquals(actualNode.getHashVariable(), windowNode.getHashVariable()); assertEquals(actualNode.getPrePartitionedInputs(), windowNode.getPrePartitionedInputs()); assertEquals(actualNode.getPreSortedOrderPrefix(), windowNode.getPreSortedOrderPrefix()); } @@ -152,12 +162,17 @@ private JsonCodec getJsonCodec() binder.install(new HandleJsonModule()); binder.bind(SqlParser.class).toInstance(sqlParser); binder.bind(TypeManager.class).toInstance(typeManager); + configBinder(binder).bindConfig(FeaturesConfig.class); + jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); + newSetBinder(binder, Type.class); jsonBinder(binder).addSerializerBinding(Slice.class).to(SliceSerializer.class); jsonBinder(binder).addDeserializerBinding(Slice.class).to(SliceDeserializer.class); jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); jsonBinder(binder).addSerializerBinding(Expression.class).to(Serialization.ExpressionSerializer.class); jsonBinder(binder).addDeserializerBinding(Expression.class).to(Serialization.ExpressionDeserializer.class); jsonBinder(binder).addDeserializerBinding(FunctionCall.class).to(Serialization.FunctionCallDeserializer.class); + jsonBinder(binder).addKeySerializerBinding(VariableReferenceExpression.class).to(Serialization.VariableReferenceExpressionSerializer.class); + jsonBinder(binder).addKeyDeserializerBinding(VariableReferenceExpression.class).to(Serialization.VariableReferenceExpressionDeserializer.class); jsonCodecBinder(binder).bindJsonCodec(WindowNode.class); }; Bootstrap app = new Bootstrap(ImmutableList.of(module)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java index 31114e24ff515..d4d95937c29cf 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java @@ -19,9 +19,9 @@ import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -55,6 +55,7 @@ public class TestValidateAggregationsWithDefaultValues private Metadata metadata; private PlanBuilder builder; private Symbol symbol; + private VariableReferenceExpression variable; private TableScanNode tableScanNode; @BeforeClass @@ -70,8 +71,9 @@ public void setup() TestingTransactionHandle.create(), Optional.of(new TpchTableLayoutHandle(nationTpchTableHandle, TupleDomain.all()))); TpchColumnHandle nationkeyColumnHandle = new TpchColumnHandle("nationkey", BIGINT); - symbol = new Symbol("nationkey"); - tableScanNode = builder.tableScan(nationTableHandle, ImmutableList.of(symbol), ImmutableMap.of(symbol, nationkeyColumnHandle)); + symbol = builder.symbol("nationkey"); + variable = builder.variable(symbol); + tableScanNode = builder.tableScan(nationTableHandle, ImmutableList.of(symbol), ImmutableList.of(variable), ImmutableMap.of(variable, nationkeyColumnHandle)); } @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Final aggregation with default value not separated from partial aggregation by remote hash exchange") @@ -79,10 +81,10 @@ public void testGloballyDistributedFinalAggregationInTheSameStageAsPartialAggreg { PlanNode root = builder.aggregation( af -> af.step(FINAL) - .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) + .groupingSets(groupingSets(ImmutableList.of(variable), 2, ImmutableSet.of(0))) .source(builder.aggregation(ap -> ap .step(PARTIAL) - .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) + .groupingSets(groupingSets(ImmutableList.of(variable), 2, ImmutableSet.of(0))) .source(tableScanNode)))); validatePlan(root, false); } @@ -92,10 +94,10 @@ public void testSingleNodeFinalAggregationInTheSameStageAsPartialAggregation() { PlanNode root = builder.aggregation( af -> af.step(FINAL) - .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) + .groupingSets(groupingSets(ImmutableList.of(variable), 2, ImmutableSet.of(0))) .source(builder.aggregation(ap -> ap .step(PARTIAL) - .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) + .groupingSets(groupingSets(ImmutableList.of(variable), 2, ImmutableSet.of(0))) .source(tableScanNode)))); validatePlan(root, true); } @@ -105,10 +107,10 @@ public void testSingleThreadFinalAggregationInTheSameStageAsPartialAggregation() { PlanNode root = builder.aggregation( af -> af.step(FINAL) - .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) + .groupingSets(groupingSets(ImmutableList.of(variable), 2, ImmutableSet.of(0))) .source(builder.aggregation(ap -> ap .step(PARTIAL) - .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) + .groupingSets(groupingSets(ImmutableList.of(variable), 2, ImmutableSet.of(0))) .source(builder.values())))); validatePlan(root, true); } @@ -116,18 +118,17 @@ public void testSingleThreadFinalAggregationInTheSameStageAsPartialAggregation() @Test public void testGloballyDistributedFinalAggregationSeparatedFromPartialAggregationByRemoteHashExchange() { - Symbol symbol = new Symbol("symbol"); PlanNode root = builder.aggregation( af -> af.step(FINAL) - .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) + .groupingSets(groupingSets(ImmutableList.of(variable), 2, ImmutableSet.of(0))) .source(builder.exchange(e -> e .type(REPARTITION) .scope(REMOTE_STREAMING) - .fixedHashDistributionParitioningScheme(ImmutableList.of(symbol), ImmutableList.of(symbol)) - .addInputsSet(symbol) + .fixedHashDistributionParitioningScheme(ImmutableList.of(variable), ImmutableList.of(variable)) + .addInputsSet(variable) .addSource(builder.aggregation(ap -> ap .step(PARTIAL) - .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) + .groupingSets(groupingSets(ImmutableList.of(variable), 2, ImmutableSet.of(0))) .source(tableScanNode)))))); validatePlan(root, false); } @@ -135,18 +136,17 @@ public void testGloballyDistributedFinalAggregationSeparatedFromPartialAggregati @Test public void testSingleNodeFinalAggregationSeparatedFromPartialAggregationByLocalHashExchange() { - Symbol symbol = new Symbol("symbol"); PlanNode root = builder.aggregation( af -> af.step(FINAL) - .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) + .groupingSets(groupingSets(ImmutableList.of(variable), 2, ImmutableSet.of(0))) .source(builder.exchange(e -> e .type(REPARTITION) .scope(LOCAL) - .fixedHashDistributionParitioningScheme(ImmutableList.of(symbol), ImmutableList.of(symbol)) - .addInputsSet(symbol) + .fixedHashDistributionParitioningScheme(ImmutableList.of(variable), ImmutableList.of(variable)) + .addInputsSet(variable) .addSource(builder.aggregation(ap -> ap .step(PARTIAL) - .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) + .groupingSets(groupingSets(ImmutableList.of(variable), 2, ImmutableSet.of(0))) .source(tableScanNode)))))); validatePlan(root, true); } @@ -154,20 +154,19 @@ public void testSingleNodeFinalAggregationSeparatedFromPartialAggregationByLocal @Test public void testWithPartialAggregationBelowJoin() { - Symbol symbol = new Symbol("symbol"); PlanNode root = builder.aggregation( af -> af.step(FINAL) - .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) + .groupingSets(groupingSets(ImmutableList.of(variable), 2, ImmutableSet.of(0))) .source(builder.join( INNER, builder.exchange(e -> e .type(REPARTITION) .scope(LOCAL) - .fixedHashDistributionParitioningScheme(ImmutableList.of(symbol), ImmutableList.of(symbol)) - .addInputsSet(symbol) + .fixedHashDistributionParitioningScheme(ImmutableList.of(variable), ImmutableList.of(variable)) + .addInputsSet(variable) .addSource(builder.aggregation(ap -> ap .step(PARTIAL) - .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) + .groupingSets(groupingSets(ImmutableList.of(variable), 2, ImmutableSet.of(0))) .source(tableScanNode)))), builder.values()))); validatePlan(root, true); @@ -176,15 +175,14 @@ public void testWithPartialAggregationBelowJoin() @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Final aggregation with default value not separated from partial aggregation by local hash exchange") public void testWithPartialAggregationBelowJoinWithoutSeparatingExchange() { - Symbol symbol = new Symbol("symbol"); PlanNode root = builder.aggregation( af -> af.step(FINAL) - .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) + .groupingSets(groupingSets(ImmutableList.of(variable), 2, ImmutableSet.of(0))) .source(builder.join( INNER, builder.aggregation(ap -> ap .step(PARTIAL) - .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) + .groupingSets(groupingSets(ImmutableList.of(variable), 2, ImmutableSet.of(0))) .source(tableScanNode)), builder.values()))); validatePlan(root, true); @@ -195,7 +193,7 @@ private void validatePlan(PlanNode root, boolean forceSingleNode) getQueryRunner().inTransaction(session -> { // metadata.getCatalogHandle() registers the catalog for the transaction session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog)); - new ValidateAggregationsWithDefaultValues(forceSingleNode).validate(root, session, metadata, SQL_PARSER, TypeProvider.empty(), WarningCollector.NOOP); + new ValidateAggregationsWithDefaultValues(forceSingleNode).validate(root, session, metadata, SQL_PARSER, builder.getTypes(), WarningCollector.NOOP); return null; }); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateStreamingAggregations.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateStreamingAggregations.java index 3b147d52a18ad..033ad1a814649 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateStreamingAggregations.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateStreamingAggregations.java @@ -68,24 +68,27 @@ public void testValidateSuccessful() validatePlan( p -> p.aggregation( a -> a.step(SINGLE) - .singleGroupingSet(p.symbol("nationkey")) + .singleGroupingSet(p.variable("nationkey")) .source( p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)))))); + ImmutableList.of(p.variable(p.symbol("nationkey", BIGINT))), + ImmutableMap.of(p.variable(p.symbol("nationkey", BIGINT)), new TpchColumnHandle("nationkey", BIGINT)))))); validatePlan( p -> p.aggregation( a -> a.step(SINGLE) - .singleGroupingSet(p.symbol("unique"), p.symbol("nationkey")) - .preGroupedSymbols(p.symbol("unique"), p.symbol("nationkey")) + .singleGroupingSet(p.variable("unique"), p.variable("nationkey")) + .preGroupedVariables(p.variable("unique"), p.variable("nationkey")) .source( - p.assignUniqueId(p.symbol("unique"), + p.assignUniqueId( + p.variable(p.symbol("unique")), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT))))))); + ImmutableList.of(p.variable(p.symbol("nationkey", BIGINT))), + ImmutableMap.of(p.variable(p.symbol("nationkey", BIGINT)), new TpchColumnHandle("nationkey", BIGINT))))))); } @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Streaming aggregation with input not grouped on the grouping keys") @@ -94,13 +97,14 @@ public void testValidateFailed() validatePlan( p -> p.aggregation( a -> a.step(SINGLE) - .singleGroupingSet(p.symbol("nationkey")) - .preGroupedSymbols(p.symbol("nationkey")) + .singleGroupingSet(p.variable("nationkey")) + .preGroupedVariables(p.variable("nationkey")) .source( p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)))))); + ImmutableList.of(p.variable(p.symbol("nationkey", BIGINT))), + ImmutableMap.of(p.variable(p.symbol("nationkey", BIGINT)), new TpchColumnHandle("nationkey", BIGINT)))))); } private void validatePlan(Function planProvider) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyOnlyOneOutputNode.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyOnlyOneOutputNode.java index 96fba8c3fea52..70f57f5060f81 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyOnlyOneOutputNode.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyOnlyOneOutputNode.java @@ -15,7 +15,7 @@ import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.OutputNode; @@ -25,6 +25,8 @@ import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; +import static com.facebook.presto.spi.type.BigintType.BIGINT; + public class TestVerifyOnlyOneOutputNode { private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); @@ -56,7 +58,7 @@ public void testValidateFailed() idAllocator.getNextId(), ImmutableList.of(), ImmutableList.of()), Assignments.of() ), ImmutableList.of(), ImmutableList.of() - ), new Symbol("a"), + ), new VariableReferenceExpression("a", BIGINT), false), ImmutableList.of(), ImmutableList.of()); new VerifyOnlyOneOutputNode().validate(root, null, null, null, null, WarningCollector.NOOP); diff --git a/presto-main/src/test/java/com/facebook/presto/type/BenchmarkDecimalOperators.java b/presto-main/src/test/java/com/facebook/presto/type/BenchmarkDecimalOperators.java index 3f4ec337b4770..e830de9c79085 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/BenchmarkDecimalOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/BenchmarkDecimalOperators.java @@ -14,13 +14,13 @@ package com.facebook.presto.type; import com.facebook.presto.RowPagesBuilder; -import com.facebook.presto.Session; import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.operator.DriverYieldSignal; import com.facebook.presto.operator.project.PageProcessor; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.DoubleType; @@ -70,7 +70,6 @@ import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static com.facebook.presto.testing.TestingConnectorSession.SESSION; -import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.google.common.collect.Iterables.getOnlyElement; import static java.math.BigInteger.ONE; import static java.math.BigInteger.ZERO; @@ -545,12 +544,11 @@ private Object execute(BaseState state) private static class BaseState { private final MetadataManager metadata = createTestMetadataManager(); - private final Session session = testSessionBuilder().build(); private final Random random = new Random(); protected final Map symbols = new HashMap<>(); protected final Map symbolTypes = new HashMap<>(); - private final Map sourceLayout = new HashMap<>(); + private final Map sourceLayout = new HashMap<>(); protected final List types = new LinkedList<>(); protected Page inputPage; @@ -572,7 +570,7 @@ protected void addSymbol(String name, Type type) Symbol symbol = new Symbol(name); symbols.put(name, symbol); symbolTypes.put(symbol, type); - sourceLayout.put(symbol, types.size()); + sourceLayout.put(new VariableReferenceExpression(name, type), types.size()); types.add(type); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/predicate/TupleDomain.java b/presto-spi/src/main/java/com/facebook/presto/spi/predicate/TupleDomain.java index dd2d365503927..13c68c5fb8189 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/predicate/TupleDomain.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/predicate/TupleDomain.java @@ -14,6 +14,7 @@ package com.facebook.presto.spi.predicate; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.type.Type; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -101,6 +102,22 @@ public static Optional> extractFixedValues(TupleDomain .collect(toLinkedMap(Map.Entry::getKey, entry -> new NullableValue(entry.getValue().getType(), entry.getValue().getNullableSingleValue())))); } + /** + * Extract all column constraints that require exactly one value or only null in their respective Domains. + * Returns an empty Optional if the Domain is none. + */ + public static Optional> extractFixedValuesToConstantExpressions(TupleDomain tupleDomain) + { + if (!tupleDomain.getDomains().isPresent()) { + return Optional.empty(); + } + + return Optional.of(tupleDomain.getDomains().get() + .entrySet().stream() + .filter(entry -> entry.getValue().isNullableSingleValue()) + .collect(toLinkedMap(Map.Entry::getKey, entry -> new ConstantExpression(entry.getValue().getNullableSingleValue(), entry.getValue().getType())))); + } + /** * Convert a map of columns to values into the TupleDomain which requires * those columns to be fixed to those values. Null is allowed as a fixed value. diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/relation/ConstantExpression.java b/presto-spi/src/main/java/com/facebook/presto/spi/relation/ConstantExpression.java index b1a53ef96abce..f808e0afdd9f9 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/relation/ConstantExpression.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/relation/ConstantExpression.java @@ -59,6 +59,11 @@ public Object getValue() return value; } + public boolean isNull() + { + return value == null; + } + @Override @JsonProperty public Type getType() diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/relation/VariableReferenceExpression.java b/presto-spi/src/main/java/com/facebook/presto/spi/relation/VariableReferenceExpression.java index 6caaa21cf4511..b73f41db67d50 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/relation/VariableReferenceExpression.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/relation/VariableReferenceExpression.java @@ -26,14 +26,15 @@ @Immutable public final class VariableReferenceExpression extends RowExpression + implements Comparable { private final String name; private final Type type; @JsonCreator public VariableReferenceExpression( - @JsonProperty String name, - @JsonProperty Type type) + @JsonProperty("name") String name, + @JsonProperty("type") Type type) { this.name = requireNonNull(name, "name is null"); this.type = requireNonNull(type, "type is null"); @@ -82,4 +83,14 @@ public boolean equals(Object obj) VariableReferenceExpression other = (VariableReferenceExpression) obj; return Objects.equals(this.name, other.name) && Objects.equals(this.type, other.type); } + + @Override + public int compareTo(VariableReferenceExpression o) + { + int nameComparison = name.compareTo(o.name); + if (nameComparison != 0) { + return nameComparison; + } + return type.getTypeSignature().toString().compareTo(o.type.getTypeSignature().toString()); + } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparator.java b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparator.java index 52a4f1a9a1965..493f660d8e537 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparator.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparator.java @@ -16,8 +16,8 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.execution.warnings.WarningCollector; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Plan; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.OutputNode; import com.facebook.presto.testing.MaterializedRow; import com.facebook.presto.testing.QueryRunner; @@ -72,11 +72,11 @@ private static List getEstimatedValuesInternal(List metr private static StatsContext buildStatsContext(Plan queryPlan, OutputNode outputNode) { - ImmutableMap.Builder columnSymbols = ImmutableMap.builder(); + ImmutableMap.Builder columnVariables = ImmutableMap.builder(); for (int columnId = 0; columnId < outputNode.getColumnNames().size(); ++columnId) { - columnSymbols.put(outputNode.getColumnNames().get(columnId), outputNode.getOutputSymbols().get(columnId)); + columnVariables.put(outputNode.getColumnNames().get(columnId), outputNode.getOutputVariables().get(columnId)); } - return new StatsContext(columnSymbols.build(), queryPlan.getTypes()); + return new StatsContext(columnVariables.build()); } private static List getActualValues(List metrics, String query, QueryRunner runner) diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/Metrics.java b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/Metrics.java index 9039054b7b515..c5dfba660906e 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/Metrics.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/Metrics.java @@ -14,7 +14,7 @@ package com.facebook.presto.tests.statistics; import com.facebook.presto.cost.PlanNodeStatsEstimate; -import com.facebook.presto.cost.SymbolStatsEstimate; +import com.facebook.presto.cost.VariableStatsEstimate; import java.util.Optional; import java.util.OptionalDouble; @@ -60,7 +60,7 @@ public static Metric nullsFraction(String columnName) @Override public OptionalDouble getValueFromPlanNodeEstimate(PlanNodeStatsEstimate planNodeStatsEstimate, StatsContext statsContext) { - return asOptional(getSymbolStatistics(planNodeStatsEstimate, columnName, statsContext).getNullsFraction()); + return asOptional(getVariableStatistics(planNodeStatsEstimate, columnName, statsContext).getNullsFraction()); } @Override @@ -90,7 +90,7 @@ public static Metric distinctValuesCount(String columnName) @Override public OptionalDouble getValueFromPlanNodeEstimate(PlanNodeStatsEstimate planNodeStatsEstimate, StatsContext statsContext) { - return asOptional(getSymbolStatistics(planNodeStatsEstimate, columnName, statsContext).getDistinctValuesCount()); + return asOptional(getVariableStatistics(planNodeStatsEstimate, columnName, statsContext).getDistinctValuesCount()); } @Override @@ -120,7 +120,7 @@ public static Metric lowValue(String columnName) @Override public OptionalDouble getValueFromPlanNodeEstimate(PlanNodeStatsEstimate planNodeStatsEstimate, StatsContext statsContext) { - double lowValue = getSymbolStatistics(planNodeStatsEstimate, columnName, statsContext).getLowValue(); + double lowValue = getVariableStatistics(planNodeStatsEstimate, columnName, statsContext).getLowValue(); if (isInfinite(lowValue)) { return OptionalDouble.empty(); } @@ -158,7 +158,7 @@ public static Metric highValue(String columnName) @Override public OptionalDouble getValueFromPlanNodeEstimate(PlanNodeStatsEstimate planNodeStatsEstimate, StatsContext statsContext) { - double highValue = getSymbolStatistics(planNodeStatsEstimate, columnName, statsContext).getHighValue(); + double highValue = getVariableStatistics(planNodeStatsEstimate, columnName, statsContext).getHighValue(); if (isInfinite(highValue)) { return OptionalDouble.empty(); } @@ -189,9 +189,9 @@ public String toString() }; } - private static SymbolStatsEstimate getSymbolStatistics(PlanNodeStatsEstimate planNodeStatsEstimate, String columnName, StatsContext statsContext) + private static VariableStatsEstimate getVariableStatistics(PlanNodeStatsEstimate planNodeStatsEstimate, String columnName, StatsContext statsContext) { - return planNodeStatsEstimate.getSymbolStatistics(statsContext.getSymbolForColumn(columnName)); + return planNodeStatsEstimate.getVariableStatistics(statsContext.getVariableForColumn(columnName)); } private static OptionalDouble asOptional(double value) diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/StatsContext.java b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/StatsContext.java index 4fec6fc4d2bbf..2cad076cec718 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/StatsContext.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/StatsContext.java @@ -13,35 +13,25 @@ */ package com.facebook.presto.tests.statistics; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.google.common.collect.ImmutableMap; import java.util.Map; import static com.google.common.base.Preconditions.checkArgument; -import static java.util.Objects.requireNonNull; public class StatsContext { - private final Map columnSymbols; - private final TypeProvider types; + private final Map columnVariables; - public StatsContext(Map columnSymbols, TypeProvider types) + public StatsContext(Map columnVariables) { - this.columnSymbols = ImmutableMap.copyOf(columnSymbols); - this.types = requireNonNull(types, "symbolTypes is null"); + this.columnVariables = ImmutableMap.copyOf(columnVariables); } - public Symbol getSymbolForColumn(String columnName) + public VariableReferenceExpression getVariableForColumn(String columnName) { - checkArgument(columnSymbols.containsKey(columnName), "no symbol found for column '" + columnName + "'"); - return columnSymbols.get(columnName); - } - - public Type getTypeForSymbol(Symbol symbol) - { - return types.get(symbol); + checkArgument(columnVariables.containsKey(columnName), "no variable found for column '" + columnName + "'"); + return columnVariables.get(columnName); } }