@@ -544,7 +544,7 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
544544 val expectedAnswerForRightOuter =
545545 spark
546546 .range(0 , 100 )
547- .flatMap(i => Seq .fill(100 )(i))
547+ .flatMap(i => Seq .fill(100 )(i))
548548 .selectExpr(" 0 as key" , " value" )
549549 checkAnswer(
550550 rightOuterJoin,
@@ -578,6 +578,153 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
578578 }
579579 }
580580
test("adaptive skewed join: left semi/anti join and skewed on right side") {
  val spark = defaultSparkSession
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false")
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true")
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10)
  withSparkSession(spark) { spark: SparkSession =>
    // Left side: small and evenly distributed (keys 0..4).
    val df1 =
      spark
        .range(0, 10, 1, 2)
        .selectExpr("id % 5 as key1", "id as value1")
    // Right side: 1000 rows all with key2 = 0, i.e. skewed on the RIGHT.
    val df2 =
      spark
        .range(0, 1000, 1, numInputPartitions)
        .selectExpr("id % 1 as key2", "id as value2")

    val leftSemiJoin =
      df1.join(df2, col("key1") === col("key2"), "left_semi").select(col("key1"), col("value1"))
    val leftAntiJoin =
      df1.join(df2, col("key1") === col("key2"), "left_anti").select(col("key1"), col("value1"))

    // Before execution, each plan contains exactly one SortMergeJoin.
    val smjBeforeExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjBeforeExecutionForLeftSemi.length === 1)

    // Fixed: this previously collected from leftSemiJoin (copy-paste error), so the
    // anti-join plan was never actually inspected.
    val smjBeforeExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjBeforeExecutionForLeftAnti.length === 1)

    // Check the answers.
    // Semi join keeps only the left rows whose key matches key2 = 0, i.e. id % 5 == 0.
    val expectedAnswerForLeftSemi =
      spark
        .range(0, 10)
        .filter(_ % 5 == 0)
        .selectExpr("id % 5 as key", "id as value")
    checkAnswer(
      leftSemiJoin,
      expectedAnswerForLeftSemi.collect())

    // Anti join keeps the complement: left rows with id % 5 != 0.
    val expectedAnswerForLeftAnti =
      spark
        .range(0, 10)
        .filter(_ % 5 != 0)
        .selectExpr("id % 5 as key", "id as value")
    checkAnswer(
      leftAntiJoin,
      expectedAnswerForLeftAnti.collect())

    // Left semi case: during execution the SMJ cannot be split into sub-joins because
    // the skew is on the right side while the join type is left semi (the skewed side
    // does not match the join type), so the single SMJ remains.
    val smjAfterExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjAfterExecutionForLeftSemi.length === 1)

    // Left anti case: same reasoning — right-side skew cannot be handled for a left
    // anti join, so the single SMJ remains.
    val smjAfterExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjAfterExecutionForLeftAnti.length === 1)
  }
}
649+
test("adaptive skewed join: left semi/anti join and skewed on left side") {
  val spark = defaultSparkSession
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false")
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true")
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10)
  val MAX_SPLIT = 5
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_MAX_SPLITS.key, MAX_SPLIT)
  withSparkSession(spark) { spark: SparkSession =>
    // Left side: 1000 rows all with key1 = 0, i.e. skewed on the LEFT.
    val df1 =
      spark
        .range(0, 1000, 1, numInputPartitions)
        .selectExpr("id % 1 as key1", "id as value1")
    // Right side: small and evenly distributed (keys 0..4).
    val df2 =
      spark
        .range(0, 10, 1, 2)
        .selectExpr("id % 5 as key2", "id as value2")

    val leftSemiJoin =
      df1.join(df2, col("key1") === col("key2"), "left_semi").select(col("key1"), col("value1"))
    val leftAntiJoin =
      df1.join(df2, col("key1") === col("key2"), "left_anti").select(col("key1"), col("value1"))

    // Before execution, each plan contains exactly one SortMergeJoin.
    val smjBeforeExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjBeforeExecutionForLeftSemi.length === 1)

    // Fixed: this previously collected from leftSemiJoin (copy-paste error), so the
    // anti-join plan was never actually inspected.
    val smjBeforeExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjBeforeExecutionForLeftAnti.length === 1)

    // Check the answers.
    // Every left key is 0 and the right side contains key 0, so the semi join keeps
    // all 1000 left rows.
    val expectedAnswerForLeftSemi =
      spark
        .range(0, 1000)
        .selectExpr("id % 1 as key", "id as value")
    checkAnswer(
      leftSemiJoin,
      expectedAnswerForLeftSemi.collect())

    // Conversely, the anti join keeps nothing.
    val expectedAnswerForLeftAnti = Seq.empty
    checkAnswer(
      leftAntiJoin,
      expectedAnswerForLeftAnti)

    // Left semi case: the skew is on the left side, which matches the join type, so
    // during execution the SMJ is split into a Union of (original SMJ + MAX_SPLIT
    // skew-handling SMJs).
    val smjAfterExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjAfterExecutionForLeftSemi.length === MAX_SPLIT + 1)

    // Left anti case: same reasoning — left-side skew is supported for left anti, so
    // the SMJ is likewise split into MAX_SPLIT + 1 joins.
    val smjAfterExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjAfterExecutionForLeftAnti.length === MAX_SPLIT + 1)

    // Both shuffle inputs must agree that partition 0 is the skewed one.
    val queryStageInputs = leftSemiJoin.queryExecution.executedPlan.collect {
      case q: ShuffleQueryStageInput => q
    }
    assert(queryStageInputs.length === 2)
    assert(queryStageInputs(0).skewedPartitions === queryStageInputs(1).skewedPartitions)
    assert(queryStageInputs(0).skewedPartitions === Some(Set(0)))

    // Each of the MAX_SPLIT sub-joins reads one skewed input per join side.
    val skewedQueryStageInputs = leftSemiJoin.queryExecution.executedPlan.collect {
      case q: SkewedShuffleQueryStageInput => q
    }
    assert(skewedQueryStageInputs.length === MAX_SPLIT * 2)
  }
}
727+
581728 test(" row count statistics, compressed" ) {
582729 val spark = defaultSparkSession
583730 withSparkSession(spark) { spark : SparkSession =>
0 commit comments