diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/CommonSubExpressionRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/CommonSubExpressionRewriter.java index 029611ccdd9f3..6a8dfbac7e57d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/CommonSubExpressionRewriter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/CommonSubExpressionRewriter.java @@ -13,6 +13,10 @@ */ package com.facebook.presto.sql.gen; +import com.facebook.presto.bytecode.BytecodeBlock; +import com.facebook.presto.bytecode.ClassDefinition; +import com.facebook.presto.bytecode.FieldDefinition; +import com.facebook.presto.bytecode.Variable; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.InputReferenceExpression; @@ -25,6 +29,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; +import com.google.common.primitives.Primitives; import java.util.ArrayList; import java.util.Collection; @@ -35,6 +40,10 @@ import java.util.Set; import java.util.stream.Collectors; +import static com.facebook.presto.bytecode.Access.PRIVATE; +import static com.facebook.presto.bytecode.Access.a; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantBoolean; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantNull; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.BIND; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN; import static com.facebook.presto.sql.relational.Expressions.subExpressions; @@ -392,4 +401,62 @@ public Integer visitSpecialForm(SpecialFormExpression specialForm, Void collect) return level; } } + + static class CommonSubExpressionFields + { + private final FieldDefinition evaluatedField; + private final FieldDefinition resultField; + 
private final Class resultType; + private final String methodName; + + public CommonSubExpressionFields(FieldDefinition evaluatedField, FieldDefinition resultField, Class resultType, String methodName) + { + this.evaluatedField = evaluatedField; + this.resultField = resultField; + this.resultType = resultType; + this.methodName = methodName; + } + + public FieldDefinition getEvaluatedField() + { + return evaluatedField; + } + + public FieldDefinition getResultField() + { + return resultField; + } + + public String getMethodName() + { + return methodName; + } + + public Class getResultType() + { + return resultType; + } + + public static Map declareCommonSubExpressionFields(ClassDefinition classDefinition, Map> commonSubExpressionsByLevel) + { + ImmutableMap.Builder fields = ImmutableMap.builder(); + commonSubExpressionsByLevel.values().stream().map(Map::values).flatMap(Collection::stream).forEach(variable -> { + Class type = Primitives.wrap(variable.getType().getJavaType()); + fields.put(variable, new CommonSubExpressionFields( + classDefinition.declareField(a(PRIVATE), variable.getName() + "Evaluated", boolean.class), + classDefinition.declareField(a(PRIVATE), variable.getName() + "Result", type), + type, + "get" + variable.getName())); + }); + return fields.build(); + } + + public static void initializeCommonSubExpressionFields(Collection cseFields, Variable thisVariable, BytecodeBlock body) + { + cseFields.forEach(fields -> { + body.append(thisVariable.setField(fields.getEvaluatedField(), constantBoolean(false))); + body.append(thisVariable.setField(fields.getResultField(), constantNull(fields.getResultType()))); + }); + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java index 1a876b14c3f8f..8d7ac11eef08d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java +++ 
b/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.gen; +import com.facebook.airlift.log.Logger; import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.BytecodeNode; import com.facebook.presto.bytecode.ClassDefinition; @@ -40,32 +41,51 @@ import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.gen.LambdaBytecodeGenerator.CompiledLambda; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Primitives; import io.airlift.slice.Slice; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import static com.facebook.presto.bytecode.Access.PRIVATE; import static com.facebook.presto.bytecode.Access.PUBLIC; import static com.facebook.presto.bytecode.Access.a; import static com.facebook.presto.bytecode.Parameter.arg; import static com.facebook.presto.bytecode.ParameterizedType.type; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantBoolean; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantFalse; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantTrue; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.or; import static com.facebook.presto.bytecode.instruction.JumpInstruction.jump; +import static com.facebook.presto.sql.gen.BytecodeUtils.boxPrimitiveIfNecessary; +import static com.facebook.presto.sql.gen.BytecodeUtils.unboxPrimitiveIfNecessary; +import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.CommonSubExpressionFields; +import static 
com.facebook.presto.sql.gen.CommonSubExpressionRewriter.CommonSubExpressionFields.declareCommonSubExpressionFields; +import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.CommonSubExpressionFields.initializeCommonSubExpressionFields; +import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.collectCSEByLevel; +import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.rewriteExpressionWithCSE; import static com.facebook.presto.sql.gen.LambdaBytecodeGenerator.generateMethodsForLambda; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.lang.String.format; public class CursorProcessorCompiler implements BodyCompiler { + private static Logger log = Logger.get(CursorProcessorCompiler.class); + private final Metadata metadata; + private final boolean isOptimizeCommonSubExpressions; - public CursorProcessorCompiler(Metadata metadata) + public CursorProcessorCompiler(Metadata metadata, boolean isOptimizeCommonSubExpressions) { this.metadata = metadata; + this.isOptimizeCommonSubExpressions = isOptimizeCommonSubExpressions; } @Override @@ -73,15 +93,36 @@ public void generateMethods(SqlFunctionProperties sqlFunctionProperties, ClassDe { CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder); - generateProcessMethod(classDefinition, projections.size()); + List rowExpressions = ImmutableList.builder() + .addAll(projections) + .add(filter) + .build(); + + Map compiledLambdaMap = generateMethodsForLambda(classDefinition, callSiteBinder, cachedInstanceBinder, rowExpressions, metadata, sqlFunctionProperties, ""); + Map cseFields = ImmutableMap.of(); + if (isOptimizeCommonSubExpressions) { + Map> commonSubExpressionsByLevel = collectCSEByLevel(rowExpressions); + + if (!commonSubExpressionsByLevel.isEmpty()) { + cseFields = declareCommonSubExpressionFields(classDefinition, 
commonSubExpressionsByLevel); + generateCommonSubExpressionMethods(metadata, sqlFunctionProperties, classDefinition, callSiteBinder, cachedInstanceBinder, compiledLambdaMap, commonSubExpressionsByLevel, cseFields); + + Map commonSubExpressions = commonSubExpressionsByLevel.values().stream() + .flatMap(m -> m.entrySet().stream()) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + + projections = rewriteRowExpressionsWithCSE(projections, commonSubExpressions); + filter = rewriteRowExpressionsWithCSE(ImmutableList.of(filter), commonSubExpressions).get(0); + } + } + + generateProcessMethod(classDefinition, projections.size(), cseFields); - Map filterCompiledLambdaMap = generateMethodsForLambda(classDefinition, callSiteBinder, cachedInstanceBinder, filter, metadata, sqlFunctionProperties, "filter"); - generateFilterMethod(sqlFunctionProperties, classDefinition, callSiteBinder, cachedInstanceBinder, filterCompiledLambdaMap, filter); + generateFilterMethod(sqlFunctionProperties, classDefinition, callSiteBinder, cachedInstanceBinder, compiledLambdaMap, filter, cseFields); for (int i = 0; i < projections.size(); i++) { String methodName = "project_" + i; - Map projectCompiledLambdaMap = generateMethodsForLambda(classDefinition, callSiteBinder, cachedInstanceBinder, projections.get(i), metadata, sqlFunctionProperties, methodName); - generateProjectMethod(sqlFunctionProperties, classDefinition, callSiteBinder, cachedInstanceBinder, projectCompiledLambdaMap, methodName, projections.get(i)); + generateProjectMethod(sqlFunctionProperties, classDefinition, callSiteBinder, cachedInstanceBinder, compiledLambdaMap, methodName, projections.get(i), cseFields); } MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC)); @@ -91,11 +132,31 @@ public void generateMethods(SqlFunctionProperties sqlFunctionProperties, ClassDe .append(thisVariable) .invokeConstructor(Object.class); + initializeCommonSubExpressionFields(cseFields.values(), 
thisVariable, constructorBody); + cachedInstanceBinder.generateInitializations(thisVariable, constructorBody); constructorBody.ret(); } - private static void generateProcessMethod(ClassDefinition classDefinition, int projections) + List rewriteRowExpressionsWithCSE( + List rows, + Map commonSubExpressions) + { + if (!commonSubExpressions.isEmpty()) { + rows = rows.stream() + .map(p -> rewriteExpressionWithCSE(p, commonSubExpressions)) + .collect(toImmutableList()); + + if (log.isDebugEnabled()) { + log.debug("Extracted %d common sub-expressions", commonSubExpressions.size()); + commonSubExpressions.entrySet().forEach(entry -> log.debug("\t%s = %s", entry.getValue(), entry.getKey())); + log.debug("Rewrote Rows: %s", rows); + } + } + return rows; + } + + private static void generateProcessMethod(ClassDefinition classDefinition, int projections, Map cseFields) { Parameter properties = arg("properties", SqlFunctionProperties.class); Parameter yieldSignal = arg("yieldSignal", DriverYieldSignal.class); @@ -115,25 +176,32 @@ private static void generateProcessMethod(ClassDefinition classDefinition, int p // while loop loop body LabelNode done = new LabelNode("done"); + + BytecodeBlock whileFunctionBlock = new BytecodeBlock() + .comment("if (pageBuilder.isFull() || yieldSignal.isSet()) return new CursorProcessorOutput(completedPositions, false);") + .append(new IfStatement() + .condition(or( + pageBuilder.invoke("isFull", boolean.class), + yieldSignal.invoke("isSet", boolean.class))) + .ifTrue(jump(done))) + .comment("if (!cursor.advanceNextPosition()) return new CursorProcessorOutput(completedPositions, true);") + .append(new IfStatement() + .condition(cursor.invoke("advanceNextPosition", boolean.class)) + .ifFalse(new BytecodeBlock() + .putVariable(finishedVariable, true) + .gotoLabel(done))); + + // reset the CSE evaluatedField = false for every row + cseFields.values().forEach(field -> whileFunctionBlock.append(scope.getThis().setField(field.getEvaluatedField(), 
constantBoolean(false)))); + + whileFunctionBlock.comment("do the projection") + .append(createProjectIfStatement(classDefinition, method, properties, cursor, pageBuilder, projections)) + .comment("completedPositions++;") + .incrementVariable(completedPositionsVariable, (byte) 1); + WhileLoop whileLoop = new WhileLoop() .condition(constantTrue()) - .body(new BytecodeBlock() - .comment("if (pageBuilder.isFull() || yieldSignal.isSet()) return new CursorProcessorOutput(completedPositions, false);") - .append(new IfStatement() - .condition(or( - pageBuilder.invoke("isFull", boolean.class), - yieldSignal.invoke("isSet", boolean.class))) - .ifTrue(jump(done))) - .comment("if (!cursor.advanceNextPosition()) return new CursorProcessorOutput(completedPositions, true);") - .append(new IfStatement() - .condition(cursor.invoke("advanceNextPosition", boolean.class)) - .ifFalse(new BytecodeBlock() - .putVariable(finishedVariable, true) - .gotoLabel(done))) - .comment("do the projection") - .append(createProjectIfStatement(classDefinition, method, properties, cursor, pageBuilder, projections)) - .comment("completedPositions++;") - .incrementVariable(completedPositionsVariable, (byte) 1)); + .body(whileFunctionBlock); method.getBody() .append(whileLoop) @@ -194,7 +262,8 @@ private void generateFilterMethod( CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, Map compiledLambdaMap, - RowExpression filter) + RowExpression filter, + Map cseFields) { Parameter properties = arg("properties", SqlFunctionProperties.class); Parameter cursor = arg("cursor", RecordCursor.class); @@ -203,30 +272,30 @@ private void generateFilterMethod( method.comment("Filter: %s", filter); Scope scope = method.getScope(); - Variable wasNullVariable = scope.declareVariable(type(boolean.class), "wasNull"); + BytecodeBlock body = method.getBody(); + Variable wasNullVariable = scope.declareVariable(type(boolean.class), "wasNull"); RowExpressionCompiler compiler = new RowExpressionCompiler( 
classDefinition, callSiteBinder, cachedInstanceBinder, - fieldReferenceCompiler(cursor), + fieldReferenceCompiler(cseFields), metadata, sqlFunctionProperties, compiledLambdaMap); LabelNode end = new LabelNode("end"); - method.getBody() - .comment("boolean wasNull = false;") - .putVariable(wasNullVariable, false) - .comment("evaluate filter: " + filter) - .append(compiler.compile(filter, scope, Optional.empty())) - .comment("if (wasNull) return false;") - .getVariable(wasNullVariable) - .ifFalseGoto(end) - .pop(boolean.class) - .push(false) - .visitLabel(end) - .retBoolean(); + body.comment("boolean wasNull = false;") + .putVariable(wasNullVariable, false) + .comment("evaluate filter: " + filter) + .append(compiler.compile(filter, scope, Optional.empty())) + .comment("if (wasNull) return false;") + .getVariable(wasNullVariable) + .ifFalseGoto(end) + .pop(boolean.class) + .push(false) + .visitLabel(end) + .retBoolean(); } private void generateProjectMethod( @@ -236,7 +305,8 @@ private void generateProjectMethod( CachedInstanceBinder cachedInstanceBinder, Map compiledLambdaMap, String methodName, - RowExpression projection) + RowExpression projection, + Map cseFields) { Parameter properties = arg("properties", SqlFunctionProperties.class); Parameter cursor = arg("cursor", RecordCursor.class); @@ -246,17 +316,16 @@ private void generateProjectMethod( method.comment("Projection: %s", projection.toString()); Scope scope = method.getScope(); - Variable wasNullVariable = scope.declareVariable(type(boolean.class), "wasNull"); - RowExpressionCompiler compiler = new RowExpressionCompiler( classDefinition, callSiteBinder, cachedInstanceBinder, - fieldReferenceCompiler(cursor), + fieldReferenceCompiler(cseFields), metadata, sqlFunctionProperties, compiledLambdaMap); + Variable wasNullVariable = scope.declareVariable(type(boolean.class), "wasNull"); method.getBody() .comment("boolean wasNull = false;") .putVariable(wasNullVariable, false) @@ -265,7 +334,77 @@ private void 
generateProjectMethod( .ret(); } - private static RowExpressionVisitor fieldReferenceCompiler(Variable cursorVariable) + private List generateCommonSubExpressionMethods( + Metadata metadata, + SqlFunctionProperties sqlFunctionProperties, + ClassDefinition classDefinition, + CallSiteBinder callSiteBinder, + CachedInstanceBinder cachedInstanceBinder, + Map compiledLambdaMap, + Map> commonSubExpressionsByLevel, + Map commonSubExpressionFieldsMap) + { + Parameter properties = arg("properties", SqlFunctionProperties.class); + Parameter cursor = arg("cursor", RecordCursor.class); + + ImmutableList.Builder methods = ImmutableList.builder(); + Map cseMap = new HashMap<>(); + int startLevel = commonSubExpressionsByLevel.keySet().stream().reduce(Math::min).get(); + int maxLevel = commonSubExpressionsByLevel.keySet().stream().reduce(Math::max).get(); + for (int i = startLevel; i <= maxLevel; i++) { + if (commonSubExpressionsByLevel.containsKey(i)) { + for (Map.Entry entry : commonSubExpressionsByLevel.get(i).entrySet()) { + RowExpression cse = entry.getKey(); + Class type = Primitives.wrap(cse.getType().getJavaType()); + VariableReferenceExpression cseVariable = entry.getValue(); + CommonSubExpressionFields cseFields = commonSubExpressionFieldsMap.get(cseVariable); + MethodDefinition method = classDefinition.declareMethod( + a(PRIVATE), + "get" + cseVariable.getName(), + type(cseFields.getResultType()), + properties, + cursor); + + method.comment("cse: %s", cse); + + Scope scope = method.getScope(); + BytecodeBlock body = method.getBody(); + Variable thisVariable = method.getThis(); + + scope.declareVariable("wasNull", body, constantFalse()); + RowExpressionCompiler cseCompiler = new RowExpressionCompiler( + classDefinition, + callSiteBinder, + cachedInstanceBinder, + fieldReferenceCompiler(cseMap), + metadata, + sqlFunctionProperties, + compiledLambdaMap); + + IfStatement ifStatement = new IfStatement() + .condition(thisVariable.getField(cseFields.getEvaluatedField())) + 
.ifFalse(new BytecodeBlock() + .append(thisVariable) + .append(cseCompiler.compile(cse, scope, Optional.empty())) + .append(boxPrimitiveIfNecessary(scope, type)) + .putField(cseFields.getResultField()) + .append(thisVariable.setField(cseFields.getEvaluatedField(), constantBoolean(true)))); + + body.append(ifStatement) + .append(thisVariable) + .getField(cseFields.getResultField()) + .retObject(); + + methods.add(method); + cseMap.put(cseVariable, cseFields); + } + } + } + return methods.build(); + } + + static RowExpressionVisitor fieldReferenceCompiler( + Map variableMap) { return new RowExpressionVisitor() { @@ -275,6 +414,7 @@ public BytecodeNode visitInputReference(InputReferenceExpression node, Scope sco int field = node.getField(); Type type = node.getType(); Variable wasNullVariable = scope.getVariable("wasNull"); + Variable cursorVariable = scope.getVariable("cursor"); Class javaType = type.getJavaType(); if (!javaType.isPrimitive() && javaType != Slice.class) { @@ -321,7 +461,14 @@ public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Scope context @Override public BytecodeNode visitVariableReference(VariableReferenceExpression reference, Scope context) { - throw new UnsupportedOperationException(); + CommonSubExpressionFields fields = variableMap.get(reference); + return new BytecodeBlock() + .append(context.getThis().invoke( + fields.getMethodName(), + fields.getResultType(), + context.getVariable("properties"), + context.getVariable("cursor"))) + .append(unboxPrimitiveIfNecessary(context, Primitives.wrap(reference.getType().getJavaType()))); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/ExpressionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/ExpressionCompiler.java index 5798a661044b2..b69a4d2c00862 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/ExpressionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/ExpressionCompiler.java @@ -67,7 
+67,8 @@ public ExpressionCompiler(Metadata metadata, PageFunctionCompiler pageFunctionCo this.cursorProcessors = CacheBuilder.newBuilder() .recordStats() .maximumSize(1000) - .build(CacheLoader.from(key -> compile(key.getSqlFunctionProperties(), key.getFilter(), key.getProjections(), new CursorProcessorCompiler(metadata), CursorProcessor.class))); + .build(CacheLoader.from(key -> compile(key.getSqlFunctionProperties(), key.getFilter(), key.getProjections(), new CursorProcessorCompiler(metadata, key.isOptimizeCommonSubExpression()), CursorProcessor.class))); + this.cacheStatsMBean = new CacheStatsMBean(cursorProcessors); } @@ -77,10 +78,14 @@ public CacheStatsMBean getCursorProcessorCache() { return cacheStatsMBean; } - public Supplier compileCursorProcessor(SqlFunctionProperties sqlFunctionProperties, Optional filter, List projections, Object uniqueKey) { - Class cursorProcessor = cursorProcessors.getUnchecked(new CacheKey(sqlFunctionProperties, filter, projections, uniqueKey)); + return compileCursorProcessor(sqlFunctionProperties, filter, projections, uniqueKey, true); + } + + public Supplier compileCursorProcessor(SqlFunctionProperties sqlFunctionProperties, Optional filter, List projections, Object uniqueKey, boolean isOptimizeCommonSubExpression) + { + Class cursorProcessor = cursorProcessors.getUnchecked(new CacheKey(sqlFunctionProperties, filter, projections, uniqueKey, isOptimizeCommonSubExpression)); return () -> { try { return cursorProcessor.getConstructor().newInstance(); @@ -184,13 +189,15 @@ private static final class CacheKey private final Optional filter; private final List projections; private final Object uniqueKey; + private final boolean isOptimizeCommonSubExpression; - private CacheKey(SqlFunctionProperties sqlFunctionProperties, Optional filter, List projections, Object uniqueKey) + private CacheKey(SqlFunctionProperties sqlFunctionProperties, Optional filter, List projections, Object uniqueKey, boolean isOptimizeCommonSubExpression) { 
this.sqlFunctionProperties = sqlFunctionProperties; this.filter = filter; this.uniqueKey = uniqueKey; this.projections = ImmutableList.copyOf(projections); + this.isOptimizeCommonSubExpression = isOptimizeCommonSubExpression; } public SqlFunctionProperties getSqlFunctionProperties() @@ -208,6 +215,11 @@ private List getProjections() return projections; } + private boolean isOptimizeCommonSubExpression() + { + return isOptimizeCommonSubExpression; + } + @Override public int hashCode() { @@ -227,7 +239,8 @@ public boolean equals(Object obj) return Objects.equals(this.sqlFunctionProperties, other.sqlFunctionProperties) && Objects.equals(this.filter, other.filter) && Objects.equals(this.projections, other.projections) && - Objects.equals(this.uniqueKey, other.uniqueKey); + Objects.equals(this.uniqueKey, other.uniqueKey) && + Objects.equals(this.isOptimizeCommonSubExpression, other.isOptimizeCommonSubExpression); } @Override @@ -238,6 +251,7 @@ public String toString() .add("filter", filter) .add("projections", projections) .add("uniqueKey", uniqueKey) + .add("isOptimizeCommonSubExpression", isOptimizeCommonSubExpression) .toString(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java index 1141a48ca1900..b17938e957751 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java @@ -98,6 +98,9 @@ import static com.facebook.presto.sql.gen.BytecodeUtils.boxPrimitiveIfNecessary; import static com.facebook.presto.sql.gen.BytecodeUtils.invoke; import static com.facebook.presto.sql.gen.BytecodeUtils.unboxPrimitiveIfNecessary; +import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.CommonSubExpressionFields; +import static 
com.facebook.presto.sql.gen.CommonSubExpressionRewriter.CommonSubExpressionFields.declareCommonSubExpressionFields; +import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.CommonSubExpressionFields.initializeCommonSubExpressionFields; import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.collectCSEByLevel; import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.getExpressionsPartitionedByCSE; import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.rewriteExpressionWithCSE; @@ -451,7 +454,7 @@ private List generateCommonSubExpressionMethods( MethodDefinition method = classDefinition.declareMethod( a(PRIVATE), "get" + cseVariable.getName(), - type(cseFields.resultType), + type(cseFields.getResultType()), ImmutableList.builder() .add(properties) .add(page) @@ -476,17 +479,17 @@ private List generateCommonSubExpressionMethods( sqlFunctionProperties, compiledLambdaMap); IfStatement ifStatement = new IfStatement() - .condition(thisVariable.getField(cseFields.evaluatedField)) + .condition(thisVariable.getField(cseFields.getEvaluatedField())) .ifFalse(new BytecodeBlock() .append(thisVariable) .append(cseCompiler.compile(cse, scope, Optional.empty())) .append(boxPrimitiveIfNecessary(scope, type)) - .putField(cseFields.resultField) - .append(thisVariable.setField(cseFields.evaluatedField, constantBoolean(true)))); + .putField(cseFields.getResultField()) + .append(thisVariable.setField(cseFields.getEvaluatedField(), constantBoolean(true)))); body.append(ifStatement) .append(thisVariable) - .getField(cseFields.resultField) + .getField(cseFields.getResultField()) .retObject(); methods.add(method); @@ -529,7 +532,7 @@ private MethodDefinition generateEvaluateMethod( declareBlockVariables(projections, page, scope, body); Variable wasNull = scope.declareVariable("wasNull", body, constantFalse()); - cseFields.values().forEach(fields -> body.append(thisVariable.setField(fields.evaluatedField, constantBoolean(false)))); + 
cseFields.values().forEach(fields -> body.append(thisVariable.setField(fields.getEvaluatedField(), constantBoolean(false)))); RowExpressionCompiler compiler = new RowExpressionCompiler( classDefinition, @@ -744,7 +747,7 @@ private MethodDefinition generateFilterMethod( Variable thisVariable = scope.getThis(); declareBlockVariables(ImmutableList.of(filter), page, scope, body); - cseFields.values().forEach(fields -> body.append(thisVariable.setField(fields.evaluatedField, constantBoolean(false)))); + cseFields.values().forEach(fields -> body.append(thisVariable.setField(fields.getEvaluatedField(), constantBoolean(false)))); Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse()); RowExpressionCompiler compiler = new RowExpressionCompiler( @@ -764,14 +767,6 @@ private MethodDefinition generateFilterMethod( return method; } - private static void initializeCommonSubExpressionFields(Collection cseFields, Variable thisVariable, BytecodeBlock body) - { - cseFields.forEach(fields -> { - body.append(thisVariable.setField(fields.evaluatedField, constantBoolean(false))); - body.append(thisVariable.setField(fields.resultField, constantNull(fields.resultType))); - }); - } - private static void declareBlockVariables(List expressions, Parameter page, Scope scope, BytecodeBlock body) { for (int channel : getInputChannels(expressions)) { @@ -779,20 +774,6 @@ private static void declareBlockVariables(List expressions, Param } } - private static Map declareCommonSubExpressionFields(ClassDefinition classDefinition, Map> commonSubExpressionsByLevel) - { - ImmutableMap.Builder fields = ImmutableMap.builder(); - commonSubExpressionsByLevel.values().stream().map(Map::values).flatMap(Collection::stream).forEach(variable -> { - Class type = Primitives.wrap(variable.getType().getJavaType()); - fields.put(variable, new CommonSubExpressionFields( - classDefinition.declareField(a(PRIVATE), variable.getName() + "Evaluated", boolean.class), - 
classDefinition.declareField(a(PRIVATE), variable.getName() + "Result", type), - type, - "get" + variable.getName())); - }); - return fields.build(); - } - private static List getInputChannels(Iterable expressions) { TreeSet channels = new TreeSet<>(); @@ -813,22 +794,6 @@ private static int[] toIntArray(List list) return array; } - private static class CommonSubExpressionFields - { - private final FieldDefinition evaluatedField; - private final FieldDefinition resultField; - private final Class resultType; - private final String methodName; - - public CommonSubExpressionFields(FieldDefinition evaluatedField, FieldDefinition resultField, Class resultType, String methodName) - { - this.evaluatedField = evaluatedField; - this.resultField = resultField; - this.resultType = resultType; - this.methodName = methodName; - } - } - private static class FieldAndVariableReferenceCompiler implements RowExpressionVisitor { @@ -874,7 +839,7 @@ public BytecodeNode visitVariableReference(VariableReferenceExpression reference { CommonSubExpressionFields fields = variableMap.get(reference); return new BytecodeBlock() - .append(thisVariable.invoke(fields.methodName, fields.resultType, context.getVariable("properties"), context.getVariable("page"), context.getVariable("position"))) + .append(thisVariable.invoke(fields.getMethodName(), fields.getResultType(), context.getVariable("properties"), context.getVariable("page"), context.getVariable("position"))) .append(unboxPrimitiveIfNecessary(context, Primitives.wrap(reference.getType().getJavaType()))); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index a3055ca729446..21d843467aa95 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -1280,7 +1280,7 @@ private 
PhysicalOperation visitScanFilterAndProject( try { if (columns != null) { - Supplier cursorProcessor = expressionCompiler.compileCursorProcessor(session.getSqlFunctionProperties(), filterExpression, projections, sourceNode.getId()); + Supplier cursorProcessor = expressionCompiler.compileCursorProcessor(session.getSqlFunctionProperties(), filterExpression, projections, sourceNode.getId(), isOptimizeCommonSubExpressions(session)); Supplier pageProcessor = expressionCompiler.compilePageProcessor(session.getSqlFunctionProperties(), filterExpression, projections, isOptimizeCommonSubExpressions(session), Optional.of(context.getStageExecutionId() + "_" + planNodeId)); SourceOperatorFactory operatorFactory = new ScanFilterAndProjectOperatorFactory( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/gen/CommonSubExpressionBenchmark.java b/presto-main/src/test/java/com/facebook/presto/sql/gen/CommonSubExpressionBenchmark.java index e3b9c961c0726..f79945f9bdd0b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/gen/CommonSubExpressionBenchmark.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/gen/CommonSubExpressionBenchmark.java @@ -16,6 +16,7 @@ import com.facebook.presto.SequencePageBuilder; import com.facebook.presto.Session; import com.facebook.presto.common.Page; +import com.facebook.presto.common.PageBuilder; import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.block.DictionaryBlock; import com.facebook.presto.common.type.Type; @@ -23,7 +24,10 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.operator.DriverYieldSignal; +import com.facebook.presto.operator.index.PageRecordSet; +import com.facebook.presto.operator.project.CursorProcessor; import com.facebook.presto.operator.project.PageProcessor; +import com.facebook.presto.spi.RecordSet; import com.facebook.presto.spi.relation.RowExpression; import 
com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.parser.SqlParser; @@ -66,6 +70,7 @@ import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static java.util.Collections.emptyList; import static java.util.Locale.ENGLISH; +import static java.util.stream.Collectors.toList; @State(Scope.Thread) @OutputTimeUnit(TimeUnit.NANOSECONDS) @@ -82,9 +87,11 @@ public class CommonSubExpressionBenchmark private static final int POSITIONS = 1024; private PageProcessor pageProcessor; + private CursorProcessor cursorProcessor; private Page inputPage; private Map symbolTypes; private Map sourceLayout; + private List projectionTypes; @Param({"json", "bigint", "varchar"}) String functionType; @@ -107,13 +114,17 @@ public void setup() List projections = getProjections(this.functionType); + projectionTypes = projections.stream().map(RowExpression::getType).collect(toList()); + MetadataManager metadata = createTestMetadataManager(); PageFunctionCompiler pageFunctionCompiler = new PageFunctionCompiler(metadata, 0); - pageProcessor = new ExpressionCompiler(metadata, pageFunctionCompiler).compilePageProcessor(TEST_SESSION.getSqlFunctionProperties(), Optional.of(getFilter(functionType)), projections, optimizeCommonSubExpression, Optional.empty()).get(); + ExpressionCompiler expressionCompiler = new ExpressionCompiler(metadata, pageFunctionCompiler); + pageProcessor = expressionCompiler.compilePageProcessor(TEST_SESSION.getSqlFunctionProperties(), Optional.of(getFilter(functionType)), projections, optimizeCommonSubExpression, Optional.empty()).get(); + cursorProcessor = expressionCompiler.compileCursorProcessor(TEST_SESSION.getSqlFunctionProperties(), Optional.of(getFilter(functionType)), projections, "key", optimizeCommonSubExpression).get(); } @Benchmark - public List> compute() + public List> computePage() { return ImmutableList.copyOf( pageProcessor.process( @@ -123,6 +134,22 @@ public List> compute() 
inputPage));
     }
 
+    /**
+     * Benchmarks the compiled cursor processor: exposes {@code inputPage} as a
+     * single-column {@link RecordSet}, runs {@code cursorProcessor} over its cursor
+     * and returns the page assembled from the projected output.
+     *
+     * @return the built output page, wrapped in an {@link Optional}
+     */
+    @Benchmark
+    public Optional<Page> ComputeRecordSet()
+    {
+        // The record set has one column whose type is selected by the functionType parameter.
+        List<Type> types = ImmutableList.of(TYPE_MAP.get(this.functionType));
+        PageBuilder pageBuilder = new PageBuilder(projectionTypes);
+        RecordSet recordSet = new PageRecordSet(types, inputPage);
+
+        // NOTE(review): the first argument is null here; confirm that none of the
+        // benchmarked expressions dereference it.
+        cursorProcessor.process(
+                null,
+                new DriverYieldSignal(),
+                recordSet.cursor(),
+                pageBuilder);
+
+        return Optional.of(pageBuilder.build());
+    }
+
     private RowExpression getFilter(String functionType)
     {
         if (functionType.equals("varchar")) {
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestCursorProcessorCompiler.java b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestCursorProcessorCompiler.java
new file mode 100644
index 0000000000000..aba6cf8cfa82e
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestCursorProcessorCompiler.java
@@ -0,0 +1,216 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package com.facebook.presto.sql.gen; + +import com.facebook.presto.bytecode.ClassDefinition; +import com.facebook.presto.common.Page; +import com.facebook.presto.common.PageBuilder; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.metadata.FunctionManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.operator.DriverYieldSignal; +import com.facebook.presto.operator.index.PageRecordSet; +import com.facebook.presto.operator.project.CursorProcessor; +import com.facebook.presto.spi.RecordSet; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; +import java.util.stream.IntStream; + +import static com.facebook.presto.bytecode.Access.FINAL; +import static com.facebook.presto.bytecode.Access.PUBLIC; +import static com.facebook.presto.bytecode.Access.a; +import static com.facebook.presto.bytecode.ParameterizedType.type; +import static com.facebook.presto.common.function.OperatorType.ADD; +import static com.facebook.presto.common.function.OperatorType.GREATER_THAN; +import static com.facebook.presto.common.function.OperatorType.LESS_THAN; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.AND; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF; +import static 
com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
+import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.CommonSubExpressionFields.declareCommonSubExpressionFields;
+import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.collectCSEByLevel;
+import static com.facebook.presto.sql.relational.Expressions.call;
+import static com.facebook.presto.sql.relational.Expressions.constant;
+import static com.facebook.presto.sql.relational.Expressions.field;
+import static com.facebook.presto.testing.TestingConnectorSession.SESSION;
+import static com.facebook.presto.util.CompilerUtils.makeClassName;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static java.util.stream.Collectors.toList;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+/**
+ * Tests for common sub-expression (CSE) handling in the cursor processor compiler:
+ * one test inspects the rewritten row expressions, the other compares the output of
+ * cursor processors compiled with and without CSE optimization.
+ */
+public class TestCursorProcessorCompiler
+{
+    private static final Metadata METADATA = createTestMetadataManager();
+    private static final FunctionManager FUNCTION_MANAGER = METADATA.getFunctionManager();
+
+    // field(0) + field(1): the sub-expression shared by the filters and projections below.
+    private static final CallExpression ADD_X_Y = call(
+            ADD.name(),
+            FUNCTION_MANAGER.resolveOperator(ADD, fromTypes(BIGINT, BIGINT)),
+            BIGINT,
+            field(0, BIGINT),
+            field(1, BIGINT));
+
+    // (field(0) + field(1)) > 2
+    private static final CallExpression ADD_X_Y_GREATER_THAN_2 = call(
+            GREATER_THAN.name(),
+            FUNCTION_MANAGER.resolveOperator(GREATER_THAN, fromTypes(BIGINT, BIGINT)),
+            BOOLEAN,
+            ADD_X_Y,
+            constant(2L, BIGINT));
+
+    // (field(0) + field(1)) < 10
+    private static final CallExpression ADD_X_Y_LESS_THAN_10 = call(
+            LESS_THAN.name(),
+            FUNCTION_MANAGER.resolveOperator(LESS_THAN, fromTypes(BIGINT, BIGINT)),
+            BOOLEAN,
+            ADD_X_Y,
+            constant(10L, BIGINT));
+
+    // (field(0) + field(1)) + field(2); the inner addition is built as a separate
+    // instance rather than by reusing the ADD_X_Y constant.
+    private static final CallExpression ADD_X_Y_Z = call(
+            ADD.name(),
+            FUNCTION_MANAGER.resolveOperator(ADD, fromTypes(BIGINT, BIGINT)),
+            BIGINT,
+            call(
+                    ADD.name(),
+                    FUNCTION_MANAGER.resolveOperator(ADD, fromTypes(BIGINT, BIGINT)),
+                    BIGINT,
+                    field(0, BIGINT),
+                    field(1, BIGINT)),
+            field(2, BIGINT));
+
+    /**
+     * Collects the common sub-expressions of a filter and a projection, declares the
+     * CSE fields on a fresh class definition, and checks that both rewritten
+     * expressions reference the single CSE variable.
+     */
+    @Test
+    public void testRewriteRowExpressionWithCSE()
+    {
+        CursorProcessorCompiler cseCursorCompiler = new CursorProcessorCompiler(METADATA, true);
+
+        ClassDefinition cursorProcessorClassDefinition = new ClassDefinition(
+                a(PUBLIC, FINAL),
+                makeClassName(CursorProcessor.class.getSimpleName()),
+                type(Object.class),
+                type(CursorProcessor.class));
+
+        // AND yields a boolean, so the special form is typed BOOLEAN.
+        RowExpression filter = new SpecialFormExpression(AND, BOOLEAN, ADD_X_Y_GREATER_THAN_2);
+        List<RowExpression> projections = ImmutableList.of(ADD_X_Y_Z);
+        List<RowExpression> rowExpressions = ImmutableList.<RowExpression>builder()
+                .addAll(projections)
+                .add(filter)
+                .build();
+        Map<Integer, Map<RowExpression, VariableReferenceExpression>> commonSubExpressionsByLevel = collectCSEByLevel(rowExpressions);
+
+        Map<VariableReferenceExpression, CommonSubExpressionRewriter.CommonSubExpressionFields> cseFields = declareCommonSubExpressionFields(cursorProcessorClassDefinition, commonSubExpressionsByLevel);
+        // Flatten the per-level maps into one expression -> variable map.
+        Map<RowExpression, VariableReferenceExpression> commonSubExpressions = commonSubExpressionsByLevel.values().stream()
+                .flatMap(m -> m.entrySet().stream())
+                .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
+        // X+Y is the only CSE (TestNG argument order is actual, expected)
+        assertEquals(cseFields.size(), 1);
+        VariableReferenceExpression cseVariable = cseFields.keySet().iterator().next();
+
+        RowExpression rewrittenFilter = cseCursorCompiler.rewriteRowExpressionsWithCSE(ImmutableList.of(filter), commonSubExpressions).get(0);
+
+        List<RowExpression> rewrittenProjections = cseCursorCompiler.rewriteRowExpressionsWithCSE(projections, commonSubExpressions);
+
+        // X+Y+Z contains CSE X+Y
+        assertTrue(((CallExpression) rewrittenProjections.get(0)).getArguments().contains(cseVariable));
+
+        // X+Y > 2 contains CSE X+Y
+        assertTrue(((CallExpression) ((SpecialFormExpression) rewrittenFilter).getArguments().get(0)).getArguments().contains(cseVariable));
+    }
+
+    /**
+     * Compiles the same filter and projections with and without CSE optimization,
+     * runs both cursor processors over the same input, and checks that the two
+     * output pages are equal.
+     */
+    @Test
+    public void testCompilerWithCSE()
+    {
+        PageFunctionCompiler functionCompiler = new PageFunctionCompiler(METADATA, 0);
+        ExpressionCompiler expressionCompiler = new ExpressionCompiler(METADATA, functionCompiler);
+
+        // AND yields a boolean, so the special form is typed BOOLEAN.
+        RowExpression filter = new SpecialFormExpression(AND, BOOLEAN, ADD_X_Y_GREATER_THAN_2, ADD_X_Y_LESS_THAN_10);
+        List<RowExpression> projections = createIfProjectionList(5);
+
+        Supplier<CursorProcessor> cseCursorProcessorSupplier = expressionCompiler.compileCursorProcessor(SESSION.getSqlFunctionProperties(), Optional.of(filter), projections, "key", true);
+        Supplier<CursorProcessor> noCseCursorProcessorSupplier = expressionCompiler.compileCursorProcessor(SESSION.getSqlFunctionProperties(), Optional.of(filter), projections, "key", false);
+
+        // Two BIGINT columns, each holding the values 0..9.
+        Page input = createLongBlockPage(2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
+
+        List<Type> types = ImmutableList.of(BIGINT, BIGINT);
+        PageBuilder pageBuilder = new PageBuilder(projections.stream().map(RowExpression::getType).collect(toList()));
+        RecordSet recordSet = new PageRecordSet(types, input);
+        cseCursorProcessorSupplier.get().process(SESSION.getSqlFunctionProperties(), new DriverYieldSignal(), recordSet.cursor(), pageBuilder);
+
+        Page pageFromCSE = pageBuilder.build();
+        // Reuse the builder for the second run.
+        pageBuilder.reset();
+
+        noCseCursorProcessorSupplier.get().process(SESSION.getSqlFunctionProperties(), new DriverYieldSignal(), recordSet.cursor(), pageBuilder);
+        Page pageFromNoCSE = pageBuilder.build();
+
+        checkPageEqual(pageFromCSE, pageFromNoCSE);
+    }
+
+    /**
+     * Builds a page of {@code blockCount} BIGINT blocks, each containing every value
+     * in {@code values} in order.
+     */
+    private static Page createLongBlockPage(int blockCount, long... values)
+    {
+        Block[] blocks = new Block[blockCount];
+        for (int i = 0; i < blockCount; i++) {
+            BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(values.length);
+            for (long value : values) {
+                BIGINT.writeLong(builder, value);
+            }
+            blocks[i] = builder.build();
+        }
+        return new Page(blocks);
+    }
+
+    /**
+     * Creates {@code projectionCount} projections of the form
+     * {@code IF(X + Y > 8, i, i + 1)}, all sharing the sub-expression X + Y.
+     */
+    private static List<RowExpression> createIfProjectionList(int projectionCount)
+    {
+        return IntStream.range(0, projectionCount)
+                .mapToObj(i -> new SpecialFormExpression(
+                        IF,
+                        BIGINT,
+                        call(
+                                GREATER_THAN.name(),
+                                FUNCTION_MANAGER.resolveOperator(GREATER_THAN, fromTypes(BIGINT, BIGINT)),
+                                BOOLEAN,
+                                ADD_X_Y,
+                                constant(8L, BIGINT)),
+                        constant((long) i, BIGINT),
+                        constant((long) i + 1, BIGINT)))
+                .collect(toImmutableList());
+    }
+
+    /**
+     * Asserts that two blocks have the same position count and the same long value
+     * at every position.
+     */
+    private static void checkBlockEqual(Block actual, Block expected)
+    {
+        assertEquals(actual.getPositionCount(), expected.getPositionCount());
+        for (int position = 0; position < actual.getPositionCount(); position++) {
+            assertEquals(actual.getLong(position), expected.getLong(position));
+        }
+    }
+
+    /**
+     * Asserts that two pages have the same shape and equal blocks in every channel.
+     */
+    private static void checkPageEqual(Page actual, Page expected)
+    {
+        assertEquals(actual.getPositionCount(), expected.getPositionCount());
+        assertEquals(actual.getChannelCount(), expected.getChannelCount());
+        // Blocks are indexed by channel, not by row position.
+        for (int channel = 0; channel < actual.getChannelCount(); channel++) {
+            checkBlockEqual(actual.getBlock(channel), expected.getBlock(channel));
+        }
+    }
+}