Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,50 @@
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.MapType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.RowExpressionRewriter;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.metadata.CastType;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.List;
import java.util.Set;

import static com.facebook.presto.SystemSessionProperties.isRemoveMapCastEnabled;
import static com.facebook.presto.common.function.OperatorType.SUBSCRIPT;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.TypeUtils.readNativeValue;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.castToInteger;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.Expressions.tryCast;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

/**
* Remove cast on map if possible. Currently it only supports subscript and element_at function, and only works when map key is of type integer and index is bigint. For example:
* Input: cast(feature as map<bigint, float>)[key], where feature is of type map<integer, float> and key is of type bigint
* Remove cast on map if possible. Currently it only supports subscript, element_at and map_subset function, and only works when map key is of type integer and index is bigint. For example:
* Input: cast(feature as map<bigint, double>)[key], where feature is of type map<integer, float> and key is of type bigint
* Output: feature[cast(key as integer)]
*
* Input: element_at(cast(feature as map<bigint, float>), key), where feature is of type map<integer, float> and key is of type bigint
* <p>
* Input: element_at(cast(feature as map<bigint, double>), key), where feature is of type map<integer, double> and key is of type bigint
* Output: element_at(feature, try_cast(key as integer))
*
* <p>
* Input: map_subset(cast(feature as map<bigint, double>), array[k1, k2]) where feature is of type map<integer, double> and key is of type array<bigint>
* Output: cast(map_subset(feature, array[try_cast(k1 as integer), try_cast(k2 as integer)]) as map<bigint, double>)
* When k1, or k2 is out of integer range, try_cast will return NULL, where map_subset will not return values for this key, which is the same behavior for both input and output
* <p>
* Notice that here when it's accessing the map using subscript function, we use CAST function in index, and when it's element_at function, we use TRY_CAST function, so that
* when the key is out of integer range, for feature[key] it will fail both with and without optimization, fail with map key not exists before optimization and with cast failure after optimization
* when the key is out of integer range, for element_at(feature, key) it will return NULL both before and after optimization
Expand Down Expand Up @@ -101,7 +115,8 @@ private RemoveMapCastRewriter(FunctionAndTypeManager functionAndTypeManager)
@Override
public RowExpression rewriteCall(CallExpression node, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
{
if ((functionResolution.isSubscriptFunction(node.getFunctionHandle()) || functionResolution.isElementAtFunction(node.getFunctionHandle())) && node.getArguments().get(0) instanceof CallExpression
if ((functionResolution.isSubscriptFunction(node.getFunctionHandle()) || functionResolution.isElementAtFunction(node.getFunctionHandle()) || functionResolution.isMapSubSetFunction(node.getFunctionHandle()))
&& node.getArguments().get(0) instanceof CallExpression
&& functionResolution.isCastFunction(((CallExpression) node.getArguments().get(0)).getFunctionHandle())
&& ((CallExpression) node.getArguments().get(0)).getArguments().get(0).getType() instanceof MapType) {
CallExpression castExpression = (CallExpression) node.getArguments().get(0);
Expand All @@ -116,18 +131,42 @@ public RowExpression rewriteCall(CallExpression node, Void context, RowExpressio
RowExpression newIndex = castToInteger(functionAndTypeManager, node.getArguments().get(1));
return call(SUBSCRIPT.name(), functionResolution.subscriptFunction(castInput.getType(), newIndex.getType()), node.getType(), castInput, newIndex);
}
else {
else if (functionResolution.isElementAtFunction(node.getFunctionHandle())) {
RowExpression newIndex = tryCast(functionAndTypeManager, node.getArguments().get(1), INTEGER);
return call(functionAndTypeManager, "element_at", node.getType(), castInput, newIndex);
}
else if (functionResolution.isMapSubSetFunction(node.getFunctionHandle())) {
RowExpression newKeyArray = null;
if (node.getArguments().get(1) instanceof CallExpression && functionResolution.isArrayConstructor(((CallExpression) node.getArguments().get(1)).getFunctionHandle())) {
CallExpression arrayConstruct = (CallExpression) node.getArguments().get(1);
List<RowExpression> newArguments = arrayConstruct.getArguments().stream().map(x -> tryCast(functionAndTypeManager, x, INTEGER)).collect(toImmutableList());
newKeyArray = call(functionAndTypeManager, "array_constructor", new ArrayType(INTEGER), newArguments);
}
else if (node.getArguments().get(1) instanceof ConstantExpression) {
ConstantExpression constantArray = (ConstantExpression) node.getArguments().get(1);
checkState(constantArray.getValue() instanceof Block && constantArray.getType() instanceof ArrayType);
Block arrayValue = (Block) constantArray.getValue();
Type arrayElementType = ((ArrayType) constantArray.getType()).getElementType();
ImmutableList.Builder<RowExpression> arguments = ImmutableList.builder();
for (int i = 0; i < arrayValue.getPositionCount(); ++i) {
ConstantExpression mapKey = constant(readNativeValue(arrayElementType, arrayValue, i), arrayElementType);
arguments.add(tryCast(functionAndTypeManager, mapKey, INTEGER));
}
newKeyArray = call(functionAndTypeManager, "array_constructor", new ArrayType(INTEGER), arguments.build());
}
if (newKeyArray != null) {
CallExpression mapSubset = call(functionAndTypeManager, "map_subset", castInput.getType(), castInput, newKeyArray);
return call("CAST", functionAndTypeManager.lookupCast(CastType.CAST, mapSubset.getType(), node.getType()), node.getType(), mapSubset);
}
}
}
}
return null;
}

private static boolean canRemoveMapCast(Type fromKeyType, Type fromValueType, Type toKeyType, Type toValueType, Type indexType)
private static boolean canRemoveMapCast(Type fromKeyType, Type fromValueType, Type toKeyType, Type toValueType, Type subsetKeysType)
{
return fromValueType.equals(toValueType) && fromKeyType.equals(INTEGER) && toKeyType.equals(BIGINT) && indexType.equals(BIGINT);
return fromValueType.equals(toValueType) && fromKeyType.equals(INTEGER) && toKeyType.equals(BIGINT) && (subsetKeysType.equals(BIGINT) || (subsetKeysType instanceof ArrayType && ((ArrayType) subsetKeysType).getElementType().equals(BIGINT)));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,44 @@ public void testElementAtCast()
ImmutableMap.of("a", expression("element_at(feature, try_cast(key as integer))")),
values("feature", "key")));
}

@Test
public void testMapSubSet()
{
tester().assertThat(
ImmutableSet.<Rule<?>>builder().addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager()).rules()).addAll(new RemoveMapCastRule(getFunctionManager()).rules()).build())
.setSystemProperty(REMOVE_MAP_CAST, "true")
.on(p -> {
VariableReferenceExpression a = p.variable("a", DOUBLE);
VariableReferenceExpression feature = p.variable("feature", createMapType(getFunctionManager(), INTEGER, DOUBLE));
VariableReferenceExpression key = p.variable("key", BIGINT);
return p.project(
assignment(a, p.rowExpression("map_subset(cast(feature as map<bigint, double>), array[key])")),
p.values(feature, key));
})
.matches(
project(
ImmutableMap.of("a", expression("cast(map_subset(feature, array[try_cast(key as integer)]) as map<bigint, double>)")),
values("feature", "key")));
}

@Test
public void testMapSubSetConstantArray()
{
tester().assertThat(
ImmutableSet.<Rule<?>>builder().addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager()).rules()).addAll(new RemoveMapCastRule(getFunctionManager()).rules()).build())
.setSystemProperty(REMOVE_MAP_CAST, "true")
.on(p -> {
VariableReferenceExpression a = p.variable("a", DOUBLE);
VariableReferenceExpression feature = p.variable("feature", createMapType(getFunctionManager(), INTEGER, DOUBLE));
VariableReferenceExpression key = p.variable("key", BIGINT);
return p.project(
assignment(a, p.rowExpression("map_subset(cast(feature as map<bigint, double>), array[cast(1 as bigint)])")),
p.values(feature, key));
})
.matches(
project(
ImmutableMap.of("a", expression("cast(map_subset(feature, array[1]) as map<bigint, double>)")),
values("feature", "key")));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7929,6 +7929,16 @@ public void testRemoveMapCast()
".*Out of range for integer.*");
assertQuery(enableOptimization, "select feature[key] from (values (map(array[cast(1 as varchar), '2', '3', '4'], array[0.3, 0.5, 0.9, 0.1]), cast('2' as varchar)), (map(array[cast(1 as varchar), '2', '3', '4'], array[0.3, 0.5, 0.9, 0.1]), '4')) t(feature, key)",
"values 0.5, 0.1");

Session disableOptimization = Session.builder(getSession())
.setSystemProperty(REMOVE_MAP_CAST, "false")
.build();
assertQueryWithSameQueryRunner(enableOptimization, "select map_subset(feature, array[key]) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 4)) t(feature, key)", disableOptimization);
assertQueryWithSameQueryRunner(enableOptimization, "select map_subset(feature, array[1, cast(3 as bigint)]) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 4)) t(feature, key)", disableOptimization);
assertQueryWithSameQueryRunner(enableOptimization, "select map_subset(feature, array[1, 3, key]) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 4)) t(feature, key)", disableOptimization);
assertQueryWithSameQueryRunner(enableOptimization, "select map_subset(feature, array[key]) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 400000000000)) t(feature, key)", disableOptimization);
assertQueryWithSameQueryRunner(enableOptimization, "select map_subset(feature, array[1, 400000000000]) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 400000000000)) t(feature, key)", disableOptimization);
assertQueryWithSameQueryRunner(enableOptimization, "select map_subset(feature, array[key, 2, 400000000000]) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 400000000000)) t(feature, key)", disableOptimization);
}

// Test to guardrail problems in constraint framework mentioned in https://github.com/prestodb/presto/pull/22171
Expand Down
Loading