Commit d809219

wangyum authored and GitHub Enterprise committed
[CARMEL-6265] Only push down low cost expression (#1081)
* Only push down low cost expression
* fix
1 parent c67115a commit d809219
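
This commit gates the partial-aggregation pushdown in PushPartialAggregationThroughJoin on the estimated cost of the aggregate expression: Sum, Min, Max, First, Last, Average, and Count are now pushed below the join only when PredicateReorder.expressionCost(ae) <= 100. A new regression test checks that an expensive aggregate built from If, Cast, and likeAny is left above the join.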

2 files changed: +22 -9 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushPartialAggregationThroughJoin.scala

Lines changed: 11 additions & 7 deletions
@@ -186,20 +186,24 @@ object PushPartialAggregationThroughJoin extends Rule[LogicalPlan]
     }).asInstanceOf[NamedExpression]
   }
 
+  private def lowerCostExp(ae: AggregateExpression): Boolean = {
+    PredicateReorder.expressionCost(ae) <= 100
+  }
+
   private def pushableAggExp(ae: AggregateExpression): Boolean = ae match {
-    case AggregateExpression(_: Sum, Complete, false, None, _) => true
-    case AggregateExpression(_: Min, Complete, false, None, _) => true
-    case AggregateExpression(_: Max, Complete, false, None, _) => true
-    case AggregateExpression(_: First, Complete, false, None, _) => true
-    case AggregateExpression(_: Last, Complete, false, None, _) => true
+    case AggregateExpression(_: Sum, Complete, false, None, _) => lowerCostExp(ae)
+    case AggregateExpression(_: Min, Complete, false, None, _) => lowerCostExp(ae)
+    case AggregateExpression(_: Max, Complete, false, None, _) => lowerCostExp(ae)
+    case AggregateExpression(_: First, Complete, false, None, _) => lowerCostExp(ae)
+    case AggregateExpression(_: Last, Complete, false, None, _) => lowerCostExp(ae)
     case AggregateExpression(Average(e), Complete, false, None, _) =>
-      e.dataType.isInstanceOf[NumericType]
+      e.dataType.isInstanceOf[NumericType] && lowerCostExp(ae)
     case _ => false
   }
 
   // Support count(*), count(id)
   private def pushableCountExp(ae: AggregateExpression): Boolean = ae match {
-    case AggregateExpression(_: Count, Complete, false, None, _) => true
+    case AggregateExpression(_: Count, Complete, false, None, _) => lowerCostExp(ae)
     case _ => false
   }
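
The gate relies on PredicateReorder.expressionCost, whose implementation is not shown in this diff. As a minimal, self-contained sketch of how such a check might behave (the node types, per-node weights, and the CostSketch helpers below are illustrative assumptions, not the CARMEL code), a recursive tree cost compared against a fixed threshold could look like this:

// Illustrative Scala sketch only: toy expression nodes with assumed weights,
// summed over the tree, gated by the same "cost <= 100" threshold as the rule.
sealed trait Expr { def children: Seq[Expr] }
case class AttrRef(name: String) extends Expr { val children = Seq.empty }
case class CastExpr(child: Expr) extends Expr { val children = Seq(child) }
case class LikeAnyExpr(child: Expr, patterns: Seq[String]) extends Expr { val children = Seq(child) }
case class IfExpr(cond: Expr, onTrue: Expr, onFalse: Expr) extends Expr { val children = Seq(cond, onTrue, onFalse) }
case class IntLit(value: Int) extends Expr { val children = Seq.empty }

object CostSketch {
  // Hypothetical weights: leaves are cheap, casts and string matching are not.
  private def nodeWeight(e: Expr): Int = e match {
    case _: AttrRef | _: IntLit => 1
    case _: CastExpr            => 10
    case _: IfExpr              => 10
    case l: LikeAnyExpr         => 60 * l.patterns.size
  }

  // Total cost of an expression tree: this node plus all of its children.
  def expressionCost(e: Expr): Int =
    nodeWeight(e) + e.children.map(expressionCost).sum

  // Mirrors lowerCostExp: only expressions at or below the threshold are pushable.
  def lowerCostExp(e: Expr, threshold: Int = 100): Boolean =
    expressionCost(e) <= threshold

  def main(args: Array[String]): Unit = {
    val cheap  = AttrRef("y")
    val costly = IfExpr(LikeAnyExpr(CastExpr(AttrRef("y")), Seq("1", "2")), IntLit(1), IntLit(0))
    println(lowerCostExp(cheap))   // true  -> eligible for pushdown
    println(lowerCostExp(costly))  // false -> stays above the join
  }
}

Under these assumed weights, an If over a Cast and a two-pattern likeAny (the shape used in the new test below) lands above the threshold and is not pushed down.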

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushPartialAggregationThroughJoinSuite.scala

Lines changed: 11 additions & 2 deletions
@@ -22,15 +22,15 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion.InConversion
 import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Cast, CheckOverflow, CheckOverflowInSum, Divide, Expression, Literal, PromotePrecision}
+import org.apache.spark.sql.catalyst.expressions.{Cast, CheckOverflow, CheckOverflowInSum, Divide, Expression, If, Literal, PromotePrecision}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Sum}
 import org.apache.spark.sql.catalyst.optimizer.customAnalyze._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, LongType}
+import org.apache.spark.sql.types._
 
 // Custom Analyzer to exclude DecimalPrecision rule
 object ExcludeDecimalPrecisionAnalyzer extends Analyzer(
@@ -636,4 +636,13 @@ class PushPartialAggregationThroughJoinSuite extends PlanTest {
       comparePlans(Optimize.execute(originalQuery), ColumnPruning(originalQuery))
     }
   }
+
+  test("Do not push down aggregate expressions if it's not lower cost expression") {
+    val originalQuery = testRelation1
+      .join(testRelation2, joinType = Inner, condition = Some('a === 'x))
+      .groupBy()(sum(If('y.cast(StringType) likeAny("1", "2"), 1, 0)).as("sum_y"))
+      .analyze
+
+    comparePlans(Optimize.execute(originalQuery), ColumnPruning(originalQuery))
+  }
 }
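
The new test is the negative case: the aggregate sum(If('y.cast(StringType) likeAny("1", "2"), 1, 0)) exceeds the cost threshold, so the rule does not introduce a partial aggregation below the join, and the optimized plan is expected to equal the original query with only ColumnPruning applied.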
