diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveMapCastRule.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveMapCastRule.java index 428640b5fc556..7915157f14650 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveMapCastRule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveMapCastRule.java @@ -14,36 +14,50 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.MapType; import com.facebook.presto.common.type.Type; import com.facebook.presto.expressions.RowExpressionRewriter; import com.facebook.presto.expressions.RowExpressionTreeRewriter; +import com.facebook.presto.metadata.CastType; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.relational.FunctionResolution; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import java.util.List; import java.util.Set; import static com.facebook.presto.SystemSessionProperties.isRemoveMapCastEnabled; import static com.facebook.presto.common.function.OperatorType.SUBSCRIPT; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TypeUtils.readNativeValue; import static com.facebook.presto.sql.relational.Expressions.call; import static com.facebook.presto.sql.relational.Expressions.castToInteger; +import static 
com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.relational.Expressions.tryCast; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; /** - * Remove cast on map if possible. Currently it only supports subscript and element_at function, and only works when map key is of type integer and index is bigint. For example: - * Input: cast(feature as map<bigint, double>)[key], where feature is of type map<integer, double> and key is of type bigint + * Remove cast on map if possible. Currently it only supports the subscript, element_at and map_subset functions, and only works when the map key is of type integer and the index is bigint (or array<bigint> for map_subset). For example: + * Input: cast(feature as map<bigint, double>)[key], where feature is of type map<integer, double> and key is of type bigint * Output: feature[cast(key as integer)] - * - * Input: element_at(cast(feature as map<bigint, double>), key), where feature is of type map<integer, double> and key is of type bigint + *

+ * Input: element_at(cast(feature as map<bigint, double>), key), where feature is of type map<integer, double> and key is of type bigint * Output: element_at(feature, try_cast(key as integer)) - * + *

+ * Input: map_subset(cast(feature as map<bigint, double>), array[k1, k2]), where feature is of type map<integer, double> and the key array is of type array<bigint> + * Output: cast(map_subset(feature, array[try_cast(k1 as integer), try_cast(k2 as integer)]) as map<bigint, double>) + * When k1 or k2 is out of integer range, try_cast returns NULL and map_subset returns no entry for that key, which is the same behavior for both the input and the output expression + *

* Notice that here when it's accessing the map using subscript function, we use CAST function in index, and when it's element_at function, we use TRY_CAST function, so that * when the key is out of integer range, for feature[key] it will fail both with and without optimization, fail with map key not exists before optimization and with cast failure after optimization * when the key is out of integer range, for element_at(feature, key) it will return NULL both before and after optimization @@ -101,7 +115,8 @@ private RemoveMapCastRewriter(FunctionAndTypeManager functionAndTypeManager) @Override public RowExpression rewriteCall(CallExpression node, Void context, RowExpressionTreeRewriter treeRewriter) { - if ((functionResolution.isSubscriptFunction(node.getFunctionHandle()) || functionResolution.isElementAtFunction(node.getFunctionHandle())) && node.getArguments().get(0) instanceof CallExpression + if ((functionResolution.isSubscriptFunction(node.getFunctionHandle()) || functionResolution.isElementAtFunction(node.getFunctionHandle()) || functionResolution.isMapSubSetFunction(node.getFunctionHandle())) + && node.getArguments().get(0) instanceof CallExpression && functionResolution.isCastFunction(((CallExpression) node.getArguments().get(0)).getFunctionHandle()) && ((CallExpression) node.getArguments().get(0)).getArguments().get(0).getType() instanceof MapType) { CallExpression castExpression = (CallExpression) node.getArguments().get(0); @@ -116,18 +131,42 @@ public RowExpression rewriteCall(CallExpression node, Void context, RowExpressio RowExpression newIndex = castToInteger(functionAndTypeManager, node.getArguments().get(1)); return call(SUBSCRIPT.name(), functionResolution.subscriptFunction(castInput.getType(), newIndex.getType()), node.getType(), castInput, newIndex); } - else { + else if (functionResolution.isElementAtFunction(node.getFunctionHandle())) { RowExpression newIndex = tryCast(functionAndTypeManager, node.getArguments().get(1), INTEGER); return 
call(functionAndTypeManager, "element_at", node.getType(), castInput, newIndex); } + else if (functionResolution.isMapSubSetFunction(node.getFunctionHandle())) { + RowExpression newKeyArray = null; + if (node.getArguments().get(1) instanceof CallExpression && functionResolution.isArrayConstructor(((CallExpression) node.getArguments().get(1)).getFunctionHandle())) { + CallExpression arrayConstruct = (CallExpression) node.getArguments().get(1); + List newArguments = arrayConstruct.getArguments().stream().map(x -> tryCast(functionAndTypeManager, x, INTEGER)).collect(toImmutableList()); + newKeyArray = call(functionAndTypeManager, "array_constructor", new ArrayType(INTEGER), newArguments); + } + else if (node.getArguments().get(1) instanceof ConstantExpression) { + ConstantExpression constantArray = (ConstantExpression) node.getArguments().get(1); + checkState(constantArray.getValue() instanceof Block && constantArray.getType() instanceof ArrayType); + Block arrayValue = (Block) constantArray.getValue(); + Type arrayElementType = ((ArrayType) constantArray.getType()).getElementType(); + ImmutableList.Builder arguments = ImmutableList.builder(); + for (int i = 0; i < arrayValue.getPositionCount(); ++i) { + ConstantExpression mapKey = constant(readNativeValue(arrayElementType, arrayValue, i), arrayElementType); + arguments.add(tryCast(functionAndTypeManager, mapKey, INTEGER)); + } + newKeyArray = call(functionAndTypeManager, "array_constructor", new ArrayType(INTEGER), arguments.build()); + } + if (newKeyArray != null) { + CallExpression mapSubset = call(functionAndTypeManager, "map_subset", castInput.getType(), castInput, newKeyArray); + return call("CAST", functionAndTypeManager.lookupCast(CastType.CAST, mapSubset.getType(), node.getType()), node.getType(), mapSubset); + } + } } } return null; } - private static boolean canRemoveMapCast(Type fromKeyType, Type fromValueType, Type toKeyType, Type toValueType, Type indexType) + private static boolean canRemoveMapCast(Type 
fromKeyType, Type fromValueType, Type toKeyType, Type toValueType, Type subsetKeysType) { - return fromValueType.equals(toValueType) && fromKeyType.equals(INTEGER) && toKeyType.equals(BIGINT) && indexType.equals(BIGINT); + return fromValueType.equals(toValueType) && fromKeyType.equals(INTEGER) && toKeyType.equals(BIGINT) && (subsetKeysType.equals(BIGINT) || (subsetKeysType instanceof ArrayType && ((ArrayType) subsetKeysType).getElementType().equals(BIGINT))); } } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java index 88b0c4a771f56..e133149c1713d 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java @@ -72,4 +72,44 @@ public void testElementAtCast() ImmutableMap.of("a", expression("element_at(feature, try_cast(key as integer))")), values("feature", "key"))); } + + @Test + public void testMapSubSet() + { + tester().assertThat( + ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager()).rules()).addAll(new RemoveMapCastRule(getFunctionManager()).rules()).build()) + .setSystemProperty(REMOVE_MAP_CAST, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", DOUBLE); + VariableReferenceExpression feature = p.variable("feature", createMapType(getFunctionManager(), INTEGER, DOUBLE)); + VariableReferenceExpression key = p.variable("key", BIGINT); + return p.project( + assignment(a, p.rowExpression("map_subset(cast(feature as map), array[key])")), + p.values(feature, key)); + }) + .matches( + project( + ImmutableMap.of("a", expression("cast(map_subset(feature, array[try_cast(key as integer)]) as map)")), + values("feature", "key"))); + } + + @Test + public void 
testMapSubSetConstantArray() + { + tester().assertThat( + ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager()).rules()).addAll(new RemoveMapCastRule(getFunctionManager()).rules()).build()) + .setSystemProperty(REMOVE_MAP_CAST, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", DOUBLE); + VariableReferenceExpression feature = p.variable("feature", createMapType(getFunctionManager(), INTEGER, DOUBLE)); + VariableReferenceExpression key = p.variable("key", BIGINT); + return p.project( + assignment(a, p.rowExpression("map_subset(cast(feature as map), array[cast(1 as bigint)])")), + p.values(feature, key)); + }) + .matches( + project( + ImmutableMap.of("a", expression("cast(map_subset(feature, array[1]) as map)")), + values("feature", "key"))); + } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index aca6b3add4e05..fd17cd6544c8d 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -7929,6 +7929,16 @@ public void testRemoveMapCast() ".*Out of range for integer.*"); assertQuery(enableOptimization, "select feature[key] from (values (map(array[cast(1 as varchar), '2', '3', '4'], array[0.3, 0.5, 0.9, 0.1]), cast('2' as varchar)), (map(array[cast(1 as varchar), '2', '3', '4'], array[0.3, 0.5, 0.9, 0.1]), '4')) t(feature, key)", "values 0.5, 0.1"); + + Session disableOptimization = Session.builder(getSession()) + .setSystemProperty(REMOVE_MAP_CAST, "false") + .build(); + assertQueryWithSameQueryRunner(enableOptimization, "select map_subset(feature, array[key]) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 4)) t(feature, key)", 
disableOptimization); + assertQueryWithSameQueryRunner(enableOptimization, "select map_subset(feature, array[1, cast(3 as bigint)]) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 4)) t(feature, key)", disableOptimization); + assertQueryWithSameQueryRunner(enableOptimization, "select map_subset(feature, array[1, 3, key]) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 4)) t(feature, key)", disableOptimization); + assertQueryWithSameQueryRunner(enableOptimization, "select map_subset(feature, array[key]) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 400000000000)) t(feature, key)", disableOptimization); + assertQueryWithSameQueryRunner(enableOptimization, "select map_subset(feature, array[1, 400000000000]) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 400000000000)) t(feature, key)", disableOptimization); + assertQueryWithSameQueryRunner(enableOptimization, "select map_subset(feature, array[key, 2, 400000000000]) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 400000000000)) t(feature, key)", disableOptimization); } // Test to guardrail problems in constraint framework mentioned in https://github.com/prestodb/presto/pull/22171