diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.java new file mode 100644 index 0000000000000..f72f77ba6bc76 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.java @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.flink.util.Preconditions; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.Aggregate.Group; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Calc; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexProgram; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.runtime.Utilities; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.mapping.Mappings; +import org.immutables.value.Value; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** Planner rule that removes unreferenced AggregateCall from Aggregate. */ +public abstract class PruneAggregateCallRule + extends RelRule { + + public static final ProjectPruneAggregateCallRule PROJECT_ON_AGGREGATE = + ProjectPruneAggregateCallRule.ProjectPruneAggregateCallRuleConfig.DEFAULT.toRule(); + public static final CalcPruneAggregateCallRule CALC_ON_AGGREGATE = + CalcPruneAggregateCallRule.CalcPruneAggregateCallRuleConfig.DEFAULT.toRule(); + + protected PruneAggregateCallRule(PruneAggregateCallRule.PruneAggregateCallRuleConfig config) { + super(config); + } + + protected abstract ImmutableBitSet getInputRefs(T relOnAgg); + + @Override + public boolean matches(RelOptRuleCall call) { + T relOnAgg = call.rel(0); + Aggregate agg = call.rel(1); + if (agg.getGroupType() != Group.SIMPLE + || agg.getAggCallList().isEmpty() + || + // at least output one column + (agg.getGroupCount() == 0 && agg.getAggCallList().size() == 1)) { + return false; + } + ImmutableBitSet inputRefs = getInputRefs(relOnAgg); + int[] unrefAggCallIndices = getUnrefAggCallIndices(inputRefs, agg); + return unrefAggCallIndices.length > 0; + } + + private int[] getUnrefAggCallIndices(ImmutableBitSet inputRefs, Aggregate agg) { + int groupCount = agg.getGroupCount(); + return IntStream.range(0, agg.getAggCallList().size()) + .filter(index -> !inputRefs.get(groupCount + index)) + .toArray(); + } + + @Override + public void onMatch(RelOptRuleCall call) { + T relOnAgg = call.rel(0); + Aggregate agg = call.rel(1); + ImmutableBitSet inputRefs = getInputRefs(relOnAgg); + int[] unrefAggCallIndices = getUnrefAggCallIndices(inputRefs, agg); + Preconditions.checkArgument(unrefAggCallIndices.length > 0, "requirement failed"); + + List newAggCalls = new ArrayList<>(agg.getAggCallList()); + // remove unreferenced AggCall from original aggCalls + Arrays.stream(unrefAggCallIndices) + .boxed() + .sorted(Comparator.reverseOrder()) + .forEach(index -> newAggCalls.remove((int) index)); + + if (newAggCalls.isEmpty() && agg.getGroupCount() == 0) { + // at least output one column + newAggCalls.add(agg.getAggCallList().get(0)); + unrefAggCallIndices = + Arrays.copyOfRange(unrefAggCallIndices, 1, unrefAggCallIndices.length); + } + + Aggregate newAgg = + agg.copy( + agg.getTraitSet(), + agg.getInput(), + agg.getGroupSet(), + List.of(agg.getGroupSet()), + newAggCalls); + + int newFieldIndex = 0; + // map old agg output index to new agg output index + Map mapOldToNew = new HashMap<>(); + int fieldCountOfOldAgg = agg.getRowType().getFieldCount(); + List unrefAggCallOutputIndices = + Arrays.stream(unrefAggCallIndices) + .mapToObj(i -> i + agg.getGroupCount()) + .collect(Collectors.toList()); + for (int i = 0; i < fieldCountOfOldAgg; i++) { + if (!unrefAggCallOutputIndices.contains(i)) { + mapOldToNew.put(i, newFieldIndex); + newFieldIndex++; + } + } + Preconditions.checkArgument( + mapOldToNew.size() == newAgg.getRowType().getFieldCount(), "requirement failed"); + + Mappings.TargetMapping mapping = + Mappings.target( + mapOldToNew, fieldCountOfOldAgg, newAgg.getRowType().getFieldCount()); + RelNode newRelOnAgg = createNewRel(mapping, relOnAgg, newAgg); + call.transformTo(newRelOnAgg); + } + + protected abstract RelNode createNewRel( + Mappings.TargetMapping mapping, T project, RelNode newAgg); + + public static class ProjectPruneAggregateCallRule extends PruneAggregateCallRule { + + protected ProjectPruneAggregateCallRule(ProjectPruneAggregateCallRuleConfig config) { + super(config); + } + + @Override + protected ImmutableBitSet getInputRefs(Project relOnAgg) { + return RelOptUtil.InputFinder.bits(relOnAgg.getProjects(), null); + } + + @Override + protected RelNode createNewRel( + Mappings.TargetMapping mapping, Project project, RelNode newAgg) { + List newProjects = RexUtil.apply(mapping, project.getProjects()); + if (projectsOnlyIdentity(newProjects, newAgg.getRowType().getFieldCount()) + && Utilities.compare( + project.getRowType().getFieldNames(), + newAgg.getRowType().getFieldNames()) + == 0) { + return newAgg; + } else { + return project.copy( + project.getTraitSet(), newAgg, newProjects, project.getRowType()); + } + } + + private boolean projectsOnlyIdentity(List projects, int inputFieldCount) { + if (projects.size() != inputFieldCount) { + return false; + } + return IntStream.range(0, projects.size()) + .allMatch( + index -> { + RexNode project = projects.get(index); + if (project instanceof RexInputRef) { + RexInputRef r = (RexInputRef) project; + return r.getIndex() == index; + } + return false; + }); + } + + /** Rule configuration. */ + @Value.Immutable(singleton = false) + public interface ProjectPruneAggregateCallRuleConfig + extends PruneAggregateCallRule.PruneAggregateCallRuleConfig { + ProjectPruneAggregateCallRuleConfig DEFAULT = + ImmutableProjectPruneAggregateCallRuleConfig.builder() + .build() + .withOperandSupplier( + b0 -> + b0.operand(Project.class) + .oneInput( + b1 -> + b1.operand(Aggregate.class) + .anyInputs())) + .withDescription( + "PruneAggregateCallRule_" + Project.class.getCanonicalName()); + + @Override + default ProjectPruneAggregateCallRule toRule() { + return new ProjectPruneAggregateCallRule(this); + } + } + } + + public static class CalcPruneAggregateCallRule extends PruneAggregateCallRule { + + protected CalcPruneAggregateCallRule(CalcPruneAggregateCallRuleConfig config) { + super(config); + } + + @Override + protected ImmutableBitSet getInputRefs(Calc calc) { + RexProgram program = calc.getProgram(); + RexNode condition = + program.getCondition() != null + ? program.expandLocalRef(program.getCondition()) + : null; + List projects = + program.getProjectList().stream() + .map(program::expandLocalRef) + .collect(Collectors.toList()); + return RelOptUtil.InputFinder.bits(projects, condition); + } + + @Override + protected RelNode createNewRel(Mappings.TargetMapping mapping, Calc calc, RelNode newAgg) { + RexProgram program = calc.getProgram(); + RexNode newCondition = + program.getCondition() != null + ? RexUtil.apply(mapping, program.expandLocalRef(program.getCondition())) + : null; + List projects = + program.getProjectList().stream() + .map(program::expandLocalRef) + .collect(Collectors.toList()); + List newProjects = RexUtil.apply(mapping, projects); + RexProgram newProgram = + RexProgram.create( + newAgg.getRowType(), + newProjects, + newCondition, + program.getOutputRowType().getFieldNames(), + calc.getCluster().getRexBuilder()); + if (newProgram.isTrivial() + && Utilities.compare( + calc.getRowType().getFieldNames(), + newAgg.getRowType().getFieldNames()) + == 0) { + return newAgg; + } else { + return calc.copy(calc.getTraitSet(), newAgg, newProgram); + } + } + + /** Rule configuration. */ + @Value.Immutable(singleton = false) + public interface CalcPruneAggregateCallRuleConfig extends PruneAggregateCallRuleConfig { + CalcPruneAggregateCallRuleConfig DEFAULT = + ImmutableCalcPruneAggregateCallRuleConfig.builder() + .build() + .withOperandSupplier( + b0 -> + b0.operand(Calc.class) + .oneInput( + b1 -> + b1.operand(Aggregate.class) + .anyInputs())) + .withDescription( + "PruneAggregateCallRule_" + Calc.class.getCanonicalName()); + + @Override + default CalcPruneAggregateCallRule toRule() { + return new CalcPruneAggregateCallRule(this); + } + } + } + + /** Rule configuration. */ + public interface PruneAggregateCallRuleConfig extends RelRule.Config { + @Override + PruneAggregateCallRule toRule(); + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.scala deleted file mode 100644 index bd7c479fea301..0000000000000 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.scala +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.table.planner.plan.rules.logical - -import com.google.common.collect.{ImmutableList, Maps} -import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil} -import org.apache.calcite.plan.RelOptRule.{any, operand} -import org.apache.calcite.rel.RelNode -import org.apache.calcite.rel.core.{Aggregate, AggregateCall, Calc, Project, RelFactories} -import org.apache.calcite.rel.core.Aggregate.Group -import org.apache.calcite.rex.{RexInputRef, RexNode, RexProgram, RexUtil} -import org.apache.calcite.runtime.Utilities -import org.apache.calcite.util.ImmutableBitSet -import org.apache.calcite.util.mapping.Mappings - -import java.util - -import scala.collection.JavaConversions._ - -/** Planner rule that removes unreferenced AggregateCall from Aggregate */ -abstract class PruneAggregateCallRule[T <: RelNode](topClass: Class[T]) - extends RelOptRule( - operand(topClass, operand(classOf[Aggregate], any)), - RelFactories.LOGICAL_BUILDER, - s"PruneAggregateCallRule_${topClass.getCanonicalName}") { - - protected def getInputRefs(relOnAgg: T): ImmutableBitSet - - override def matches(call: RelOptRuleCall): Boolean = { - val relOnAgg: T = call.rel(0) - val agg: Aggregate = call.rel(1) - if ( - agg.getGroupType != Group.SIMPLE || agg.getAggCallList.isEmpty || - // at least output one column - (agg.getGroupCount == 0 && agg.getAggCallList.size() == 1) - ) { - return false - } - val inputRefs = getInputRefs(relOnAgg) - val unrefAggCallIndices = getUnrefAggCallIndices(inputRefs, agg) - unrefAggCallIndices.nonEmpty - } - - private def getUnrefAggCallIndices(inputRefs: ImmutableBitSet, agg: Aggregate): Array[Int] = { - val groupCount = agg.getGroupCount - agg.getAggCallList.indices - .flatMap { - index => - val aggCallOutputIndex = groupCount + index - if (inputRefs.get(aggCallOutputIndex)) { - Array.empty[Int] - } else { - Array(index) - } - } - .toArray[Int] - } - - override def onMatch(call: RelOptRuleCall): Unit = { - val relOnAgg: T = call.rel(0) - val agg: Aggregate = call.rel(1) - val inputRefs = getInputRefs(relOnAgg) - var unrefAggCallIndices = getUnrefAggCallIndices(inputRefs, agg) - require(unrefAggCallIndices.nonEmpty) - - val newAggCalls: util.List[AggregateCall] = new util.ArrayList(agg.getAggCallList) - // remove unreferenced AggCall from original aggCalls - unrefAggCallIndices.sorted.reverse.foreach(i => newAggCalls.remove(i)) - - if (newAggCalls.isEmpty && agg.getGroupCount == 0) { - // at least output one column - newAggCalls.add(agg.getAggCallList.get(0)) - unrefAggCallIndices = unrefAggCallIndices.slice(1, unrefAggCallIndices.length) - } - - val newAgg = agg.copy( - agg.getTraitSet, - agg.getInput, - agg.getGroupSet, - ImmutableList.of(agg.getGroupSet), - newAggCalls - ) - - var newFieldIndex = 0 - // map old agg output index to new agg output index - val mapOldToNew = Maps.newHashMap[Integer, Integer]() - val fieldCountOfOldAgg = agg.getRowType.getFieldCount - val unrefAggCallOutputIndices = unrefAggCallIndices.map(_ + agg.getGroupCount) - (0 until fieldCountOfOldAgg).foreach { - i => - if (!unrefAggCallOutputIndices.contains(i)) { - mapOldToNew.put(i, newFieldIndex) - newFieldIndex += 1 - } - } - require(mapOldToNew.size() == newAgg.getRowType.getFieldCount) - - val mapping = Mappings.target(mapOldToNew, fieldCountOfOldAgg, newAgg.getRowType.getFieldCount) - val newRelOnAgg = createNewRel(mapping, relOnAgg, newAgg) - call.transformTo(newRelOnAgg) - } - - protected def createNewRel(mapping: Mappings.TargetMapping, project: T, newAgg: RelNode): RelNode -} - -class ProjectPruneAggregateCallRule extends PruneAggregateCallRule(classOf[Project]) { - override protected def getInputRefs(relOnAgg: Project): ImmutableBitSet = { - RelOptUtil.InputFinder.bits(relOnAgg.getProjects, null) - } - - override protected def createNewRel( - mapping: Mappings.TargetMapping, - project: Project, - newAgg: RelNode): RelNode = { - val newProjects = RexUtil.apply(mapping, project.getProjects).toList - if ( - projectsOnlyIdentity(newProjects, newAgg.getRowType.getFieldCount) && - Utilities.compare(project.getRowType.getFieldNames, newAgg.getRowType.getFieldNames) == 0 - ) { - newAgg - } else { - project.copy(project.getTraitSet, newAgg, newProjects, project.getRowType) - } - } - - private def projectsOnlyIdentity(projects: util.List[RexNode], inputFieldCount: Int): Boolean = { - if (projects.size != inputFieldCount) { - return false - } - projects.zipWithIndex.forall { - case (project, index) => - project match { - case r: RexInputRef => r.getIndex == index - case _ => false - } - } - } -} - -class CalcPruneAggregateCallRule extends PruneAggregateCallRule(classOf[Calc]) { - override protected def getInputRefs(relOnAgg: Calc): ImmutableBitSet = { - val program = relOnAgg.getProgram - val condition = if (program.getCondition != null) { - program.expandLocalRef(program.getCondition) - } else { - null - } - val projects = program.getProjectList.map(program.expandLocalRef) - RelOptUtil.InputFinder.bits(projects, condition) - } - - override protected def createNewRel( - mapping: Mappings.TargetMapping, - calc: Calc, - newAgg: RelNode): RelNode = { - val program = calc.getProgram - val newCondition = if (program.getCondition != null) { - RexUtil.apply(mapping, program.expandLocalRef(program.getCondition)) - } else { - null - } - val projects = program.getProjectList.map(program.expandLocalRef) - val newProjects = RexUtil.apply(mapping, projects).toList - val newProgram = RexProgram.create( - newAgg.getRowType, - newProjects, - newCondition, - program.getOutputRowType.getFieldNames, - calc.getCluster.getRexBuilder - ) - if ( - newProgram.isTrivial && - Utilities.compare(calc.getRowType.getFieldNames, newAgg.getRowType.getFieldNames) == 0 - ) { - newAgg - } else { - calc.copy(calc.getTraitSet, newAgg, newProgram) - } - } -} - -object PruneAggregateCallRule { - val PROJECT_ON_AGGREGATE = new ProjectPruneAggregateCallRule - val CALC_ON_AGGREGATE = new CalcPruneAggregateCallRule -}