@@ -3046,6 +3046,172 @@ class AdaptiveQueryExecSuite
30463046 }
30473047 }
30483048 }
3049+
/**
 * Checks the first [[AQEShuffleReadExec]] collected from `plan`: it must report
 * `hasSkewedPartition`, and its `PartialReducerPartitionSpec`s must cover exactly
 * `expectedSkewPartitions` distinct reducer indices.
 *
 * @param plan physical plan that is searched for an `AQEShuffleReadExec` node
 * @param expectedSkewPartitions expected number of distinct `reducerIndex` values among the
 *                               partial reducer partition specs of that node
 */
def checkSkewInsert(plan: SparkPlan, expectedSkewPartitions: Int): Unit = {
  val readers = plan.collect {
    case r: AQEShuffleReadExec => r
  }
  // Fail with a readable message instead of the bare NoSuchElementException that `.head`
  // would raise when the plan contains no AQE shuffle read at all.
  assert(readers.nonEmpty, s"No AQEShuffleReadExec found in plan:\n$plan")
  val reader = readers.head
  assert(reader.hasSkewedPartition)
  // assert(reader.hasCoalescedPartition) // 0-size partitions are ignored.
  val numSkewedPartitions = reader.partitionSpecs.collect {
    case p: PartialReducerPartitionSpec => p.reducerIndex
  }.distinct.length
  assert(numSkewedPartitions == expectedSkewPartitions)
}
3061+
/**
 * Unwraps `plan` so that tests can inspect the plan underneath it.
 *
 * A `CommandResultExec` is unwrapped recursively through its second constructor field, an
 * [[AdaptiveSparkPlanExec]] is replaced by its `finalPhysicalPlan`, and every other plan is
 * returned unchanged.
 */
protected def getCorePlan(plan: SparkPlan): SparkPlan = plan match {
  case adaptive: AdaptiveSparkPlanExec =>
    adaptive.finalPhysicalPlan
  case org.apache.spark.sql.execution.CommandResultExec(_, wrapped, _) =>
    getCorePlan(wrapped)
  case other =>
    other
}
3070+
/**
 * Removes a single outer `CommandResultExec`, returning its second constructor field.
 * Any other plan is returned as is.
 */
protected def stripCommandResultExec(plan: SparkPlan): SparkPlan = plan match {
  case org.apache.spark.sql.execution.CommandResultExec(_, wrapped, _) => wrapped
  case other => other
}
3077+
// Creates a source table whose 1000 rows all share one key (`id % 1`), runs a
// create-table-as-select partitioned by that key, and checks the adaptive plans reported to a
// QueryExecutionListener: each must contain exactly one shuffle query stage and one skewed
// reducer partition. Finally verifies that all 1000 rows were written.
test(" adaptive skewed insert: create as select command") {
  withTable(" tbl", " tbl2") {
    withSQLConf(
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> " true",
      SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> " true",
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> " -1",
      SQLConf.AUTO_REPARTITION_BEFORE_WRITING_ENABLED.key -> " true",
      SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key -> " 100",
      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> " 100") {

      spark
        .range(0, 1000, 1, 10)
        .selectExpr(" id % 1 as key", " id as value")
        .write.saveAsTable(" tbl")

      // Listener callbacks run on the listener bus thread, where a failed assertion is caught
      // and logged instead of failing the test. The listener therefore only records the
      // adaptive plans it is handed; every assertion is made below on the test thread.
      // The variable is written by the listener thread only and read after the bus is drained.
      @volatile var adaptivePlans: Seq[AdaptiveSparkPlanExec] = Nil
      val listener = new QueryExecutionListener {
        override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
          stripCommandResultExec(qe.executedPlan) match {
            case ae: AdaptiveSparkPlanExec => adaptivePlans = adaptivePlans :+ ae
            case _ => // Executions without an adaptive top-level plan are not of interest here.
          }
        }
        override def onFailure(funcName: String, qe: QueryExecution,
            exception: Exception): Unit = {}
      }

      // Drain pending events of the setup write so that only the statement below is observed.
      spark.sparkContext.listenerBus.waitUntilEmpty()
      spark.listenerManager.register(listener)
      try {
        spark.sql(" create table tbl2 using parquet " +
          " partitioned by (key) select * from tbl")
        // Make sure the execution-end event has been delivered before unregistering.
        spark.sparkContext.listenerBus.waitUntilEmpty()
      } finally {
        spark.listenerManager.unregister(listener)
      }

      assert(adaptivePlans.nonEmpty,
        "The listener did not observe any AdaptiveSparkPlanExec for the CTAS statement")
      adaptivePlans.foreach { ae =>
        val queryStages = ae.finalPhysicalPlan.collect {
          case qs: ShuffleQueryStageExec => qs
        }
        assert(queryStages.length == 1)
        checkSkewInsert(ae.finalPhysicalPlan, 1)
      }
      assert(sql(" select count(*) from tbl2").collect().head.getLong(0) == 1000)
    }
  }
}
3117+
// Inserts 1000 rows that all share one key (`id % 1`) into a table partitioned by that key,
// using a dynamic-partition `insert overwrite`. The unwrapped plan must contain one data
// writing command, one shuffle query stage and one skewed reducer partition, and the target
// table must end up with all 1000 rows.
test(" adaptive skewed insert: insert into command") {
  withTable(" tbl", " tbl2") {
    withSQLConf(
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> " true",
      SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> " true",
      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> " 100",
      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> " 100",
      SQLConf.PARTITION_OVERWRITE_MODE.key -> " dynamic",
      SQLConf.AUTO_REPARTITION_BEFORE_WRITING_ENABLED.key -> " true") {

      // Source data: 10 input partitions, a single distinct key.
      val source = spark.range(0, 1000, 1, 10)
        .selectExpr(" id % 1 as key", " id % 1 as value")
      source.write.saveAsTable(" tbl")
      spark.sql(" create table tbl2(key int, value int) using parquet " +
        " partitioned by (key)")

      val insertDf =
        spark.sql(" insert overwrite table tbl2 partition(key) select * from tbl")
      val corePlan = getCorePlan(insertDf.queryExecution.sparkPlan)

      val writers = corePlan.collect { case w: DataWritingCommandExec => w }
      assert(writers.size == 1)

      val shuffleStages = corePlan.collect { case qs: ShuffleQueryStageExec => qs }
      assert(shuffleStages.length == 1)
      checkSkewInsert(corePlan, 1)

      assert(sql(" select count(*) from tbl2").collect().head.getLong(0) == 1000)
    }
  }
}
3151+
// Same dynamic-partition `insert overwrite` flow as the plain insert test, but with three
// distinct keys (`id % 3`) and 8 shuffle partitions. Going by the test name this guards
// against an ArrayIndexOutOfBoundsException (CARMEL-2389); it only checks that the statement
// completes, that the plan has one writing command and one shuffle query stage, and that
// all 1000 rows arrive in the target table.
test(" CARMEL-2389 adaptive skewed insert: ArrayIndexOutOfBoundsException exception") {
  withTable(" tbl", " tbl2") {
    withSQLConf(
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> " true",
      SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> " true",
      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> " 100",
      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> " 100",
      SQLConf.PARTITION_OVERWRITE_MODE.key -> " dynamic",
      SQLConf.AUTO_REPARTITION_BEFORE_WRITING_ENABLED.key -> " true",
      SQLConf.SHUFFLE_PARTITIONS.key -> " 8") {

      val source = spark.range(0, 1000, 1, 10)
        .selectExpr(" id % 3 as key", " id % 1 as value")
      source.write.saveAsTable(" tbl")
      spark.sql(" create table tbl2(key int, value int) using parquet " +
        " partitioned by (key)")

      val insertDf =
        spark.sql(" insert overwrite table tbl2 partition(key) select * from tbl")
      val corePlan = getCorePlan(insertDf.queryExecution.sparkPlan)

      val writers = corePlan.collect { case w: DataWritingCommandExec => w }
      assert(writers.size == 1)

      val shuffleStages = corePlan.collect { case qs: ShuffleQueryStageExec => qs }
      assert(shuffleStages.length == 1)

      assert(sql(" select count(*) from tbl2").collect().head.getLong(0) == 1000)
    }
  }
}
3184+
// Runs the dynamic-partition `insert overwrite` from an empty source table. The unwrapped
// plan must still contain one data writing command but no shuffle query stage, and the
// target table must stay empty.
test(" adaptive skewed insert: insert into command, source table is empty") {
  withTable(" tbl", " tbl2") {
    withSQLConf(
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> " true",
      SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> " true",
      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> " 100",
      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> " 100",
      SQLConf.PARTITION_OVERWRITE_MODE.key -> " dynamic",
      SQLConf.AUTO_REPARTITION_BEFORE_WRITING_ENABLED.key -> " true") {

      // Both tables are created empty; nothing is ever written to the source.
      spark.sql(" create table tbl(key int, value int) using parquet " +
        " partitioned by (key)")
      spark.sql(" create table tbl2(key int, value int) using parquet " +
        " partitioned by (key)")

      val insertDf =
        spark.sql(" insert overwrite table tbl2 partition(key) select * from tbl")
      val corePlan = getCorePlan(insertDf.queryExecution.sparkPlan)

      val writers = corePlan.collect { case w: DataWritingCommandExec => w }
      assert(writers.size == 1)

      val shuffleStages = corePlan.collect { case qs: ShuffleQueryStageExec => qs }
      assert(shuffleStages.isEmpty)

      assert(sql(" select count(*) from tbl2").collect().head.getLong(0) == 0)
    }
  }
}
30493215}
30503216
30513217/**
0 commit comments