@@ -25,7 +25,7 @@ import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.internal.SQLConf
@@ -133,13 +133,21 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
 
   private def canSplitLeftSide(joinType: JoinType, plan: SparkPlan) = {
     (joinType == Inner || joinType == Cross || joinType == LeftSemi ||
-      joinType == LeftAnti || joinType == LeftOuter) &&
-      plan.find(_.isInstanceOf[HashAggregateExec]).isEmpty
+      joinType == LeftAnti || joinType == LeftOuter) && !containsAggregateExec(plan)
   }
 
   private def canSplitRightSide(joinType: JoinType, plan: SparkPlan) = {
-    (joinType == Inner || joinType == Cross || joinType == RightOuter) &&
-      plan.find(_.isInstanceOf[HashAggregateExec]).isEmpty
+    (joinType == Inner || joinType == Cross ||
+      joinType == RightOuter) && !containsAggregateExec(plan)
+  }
+
+  private def containsAggregateExec(plan: SparkPlan) = {
+    plan.find {
+      case _: HashAggregateExec => true
+      case _: SortAggregateExec => true
+      case _: ObjectHashAggregateExec => true
+      case _ => false
+    }.isDefined
   }
 
   private def getSizeInfo(medianSize: Long, sizes: Seq[Long]): String = {
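Background on the change: OptimizeSkewedJoin splits a skewed shuffle partition so that several tasks each read a slice of it. That is only safe when every operator below the join can tolerate seeing a subset of the partition's rows; an aggregate cannot, because each slice would emit its own per-group result. The old guard only looked for HashAggregateExec, but SortAggregateExec and ObjectHashAggregateExec carry the same constraint, so the new containsAggregateExec helper matches all three. The sketch below is not Spark source (Plan, Scan, Sort, and the aggregate case classes are made up for illustration); it shows the tree-search pattern the helper relies on.

object AggregateDetectionSketch extends App {
  // A toy plan tree with a pre-order search, analogous to TreeNode.find in Spark.
  sealed trait Plan {
    def children: Seq[Plan]
    def find(f: Plan => Boolean): Option[Plan] =
      if (f(this)) Some(this) else children.view.flatMap(_.find(f)).headOption
  }
  case class Scan(table: String) extends Plan { val children: Seq[Plan] = Nil }
  case class Sort(child: Plan) extends Plan { val children: Seq[Plan] = Seq(child) }
  case class HashAgg(child: Plan) extends Plan { val children: Seq[Plan] = Seq(child) }
  case class SortAgg(child: Plan) extends Plan { val children: Seq[Plan] = Seq(child) }
  case class ObjectHashAgg(child: Plan) extends Plan { val children: Seq[Plan] = Seq(child) }

  // Mirrors the new helper: true if any node in the tree is one of the
  // three aggregate operators that make splitting a skewed side unsafe.
  def containsAggregate(plan: Plan): Boolean = plan.find {
    case _: HashAgg | _: SortAgg | _: ObjectHashAgg => true
    case _ => false
  }.isDefined

  // The old check matched only the hash-aggregate node, so a side like
  // Sort(SortAgg(Scan("t"))) was wrongly treated as safe to split.
  assert(containsAggregate(Sort(SortAgg(Scan("t")))))
  assert(!containsAggregate(Sort(Scan("t"))))
}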