Skip to content

Commit 424feb3

Browse files
fenzhuGitHub Enterprise
authored and committed
[CARMEL-7546][CARMEL-3523] Optimize skewed insert (apache#324)
1 parent 35fdc95 commit 424feb3

File tree

3 files changed

+301
-0
lines changed

3 files changed

+301
-0
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ case class AdaptiveSparkPlanExec(
131131
CombineAdjacentAggregation,
132132
RemoveRedundantWindowGroupLimits,
133133
DisableUnnecessaryBucketedScan,
134+
OptimizeSkewedInsert,
134135
OptimizeSkewedJoin(ensureRequirements)
135136
) ++ context.session.sessionState.adaptiveRulesHolder.queryStagePrepRules
136137
}
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.adaptive
19+
20+
import scala.collection.mutable
21+
22+
import org.apache.commons.io.FileUtils
23+
24+
import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
25+
import org.apache.spark.sql.catalyst.rules.Rule
26+
import org.apache.spark.sql.execution._
27+
import org.apache.spark.sql.execution.command.DataWritingCommandExec
28+
import org.apache.spark.sql.execution.datasources.WriteFilesExec
29+
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION_BY_COL, ShuffleExchangeLike}
30+
import org.apache.spark.sql.internal.SQLConf
31+
32+
/**
 * Adaptive rule that splits skewed shuffle partitions feeding a file-writing command.
 *
 * The rule is a no-op unless `SQLConf.AUTO_REPARTITION_BEFORE_WRITING_ENABLED` is set. When it is
 * enabled, the rule looks for a `DataWritingCommandExec` whose child is a `WriteFilesExec` sitting
 * on top of a materialized shuffle stage (optionally with a local `SortExec` in between). Every
 * shuffle partition that `SkewHandlingUtil.isSkewed` reports as skewed is split with
 * `ShufflePartitionsUtil.createSkewPartitionSpecs`, and the stage is then read through an
 * `AQEShuffleReadExec` built from the resulting partition specs.
 */
object OptimizeSkewedInsert extends Rule[SparkPlan] {

  /** Formats the partition size statistics that are written to the log. */
  private def getSizeInfo(medianSize: Long, sizes: Seq[Long], targetSize: Long): String = {
    s"median size: $medianSize, max size: ${sizes.max}, min size: ${sizes.min}, avg size: " +
      s"${sizes.sum / sizes.length}, target size: $targetSize"
  }

  /**
   * Rewrites every eligible `DataWritingCommandExec` in `plan`; returns `plan` unchanged when the
   * feature flag is off.
   *
   * NOTE(review): `handleSkewed` receives the child of the `WriteFilesExec` but rebuilds the
   * `DataWritingCommandExec` with `withNewChildren`, so after a rewrite the command's direct
   * child is the new shuffle read (or the sort above it) and the `WriteFilesExec` node is no
   * longer part of the plan. Confirm that this is intended.
   */
  override def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.getConf(SQLConf.AUTO_REPARTITION_BEFORE_WRITING_ENABLED)) {
      plan
    } else {
      plan.transformUp {
        case w @ DataWritingCommandExec(_, WriteFilesExec(child, _, _, _, _, _))
            if supportOptimization(w) => handleSkewed(w, child)
      }
    }
  }

  /**
   * Locates the shuffle stage below the write and, if it can be optimized, splits its skewed
   * partitions.
   *
   * @param plan  the write command whose child gets replaced
   * @param child the plan below the `WriteFilesExec`: either a materialized shuffle stage, or a
   *              non-global `SortExec` directly on top of one
   * @return the rewritten command, or `plan` itself when nothing was changed
   */
  private def handleSkewed(plan: SparkPlan, child: SparkPlan): SparkPlan = child match {
    case SortExec(_, false, ShuffleStage(stage), _) =>
      // Keep the local sort and put the new shuffle read underneath it.
      splitSkewedPartitions(plan, stage, Some(child))
    case ShuffleStage(stage) =>
      splitSkewedPartitions(plan, stage, None)
    case _ =>
      plan
  }

  /**
   * Replaces the read of `stage` with an `AQEShuffleReadExec` in which every skewed partition is
   * split into several partial-reducer partitions.
   *
   * @param plan         the write command to rebuild
   * @param stage        the materialized shuffle stage that feeds the write
   * @param planToUpdate an operator between the write and the stage that must be preserved
   * @return the rewritten command, or `plan` when the shuffle is not eligible or no partition
   *         was split
   */
  private def splitSkewedPartitions(
      plan: SparkPlan,
      stage: ShuffleQueryStageExec,
      planToUpdate: Option[SparkPlan]): SparkPlan = stage.mapStats match {
    case Some(mapStats) if supportOptimizeSkew(stage.shuffle) =>
      val sizes = mapStats.bytesByPartitionId
      // We use the median size of the original shuffle partitions to detect skewed partitions.
      val medSize = SkewHandlingUtil.medianSize(mapStats)
      val targetSize = SkewHandlingUtil.targetSize(sizes, medSize, conf)
      logInfo(
        s"""
           |Optimizing skewed insert, partition size info:
           |${getSizeInfo(medSize, sizes, targetSize)}
         """.stripMargin)

      // One entry per reducer partition: Some(parts) when the partition was split, None when it
      // is read as a whole. Each reducer partition is considered on its own, so a skewed
      // partition is never part of a coalesced range.
      val splitSpecs: Seq[Option[Seq[ShufflePartitionSpec]]] = sizes.indices.map { partitionId =>
        val size = sizes(partitionId)
        val skewSpecs: Option[Seq[ShufflePartitionSpec]] =
          if (SkewHandlingUtil.isSkewed(size, medSize, conf)) {
            ShufflePartitionsUtil.createSkewPartitionSpecs(
              mapStats.shuffleId, partitionId, targetSize)
          } else {
            None
          }
        skewSpecs.foreach { parts =>
          logInfo(s"Partition $partitionId " +
            s"(${FileUtils.byteCountToDisplaySize(size)}) is skewed, " +
            s"split it into ${parts.length} parts.")
        }
        skewSpecs
      }

      val numSkewed = splitSpecs.count(_.isDefined)
      logInfo(s"number of skewed partitions: $numSkewed")
      if (numSkewed > 0) {
        val partitionSpecs: Seq[ShufflePartitionSpec] = splitSpecs.zipWithIndex.flatMap {
          case (Some(parts), _) =>
            parts
          case (None, partitionId) =>
            Seq(CoalescedPartitionSpec(partitionId, partitionId + 1, sizes(partitionId)))
        }
        val newShuffleReader = AQEShuffleReadExec(stage, partitionSpecs)
        val newChild = planToUpdate match {
          case Some(p) => p.withNewChildren(Seq(newShuffleReader))
          case None => newShuffleReader
        }
        plan.withNewChildren(newChild :: Nil)
      } else {
        plan
      }

    case _ =>
      plan
  }

  /** Only shuffles introduced by repartition-by-column or by EnsureRequirements are split. */
  private def supportOptimizeSkew(s: ShuffleExchangeLike): Boolean = {
    s.shuffleOrigin == REPARTITION_BY_COL || s.shuffleOrigin == ENSURE_REQUIREMENTS
  }

  /**
   * The rewrite is only attempted when `plan` places no distribution requirement on any of its
   * children.
   */
  private def supportOptimization(plan: SparkPlan): Boolean = {
    plan.requiredChildDistribution.forall(_ == UnspecifiedDistribution)
  }

  /** Extractor that matches a materialized shuffle query stage with map output statistics. */
  private object ShuffleStage {
    def unapply(plan: SparkPlan): Option[ShuffleQueryStageExec] = plan match {
      case s: ShuffleQueryStageExec if s.isMaterialized && s.mapStats.isDefined =>
        Some(s)
      case _ => None
    }
  }
}

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3046,6 +3046,172 @@ class AdaptiveQueryExecSuite
30463046
}
30473047
}
30483048
}
3049+
3050+
/**
 * Asserts that `plan` contains an `AQEShuffleReadExec` with skewed partitions and that the
 * number of distinct reducers that were split equals `expectedSkewPartitions`.
 *
 * Only the first `AQEShuffleReadExec` returned by `plan.collect` is inspected.
 */
def checkSkewInsert(plan: SparkPlan, expectedSkewPartitions: Int): Unit = {
  val readers = plan.collect {
    case r: AQEShuffleReadExec => r
  }
  // Fail with a readable message instead of a NoSuchElementException from `.head`.
  assert(readers.nonEmpty, s"No AQEShuffleReadExec found in plan:\n$plan")
  val reader = readers.head
  assert(reader.hasSkewedPartition)
  // `reader.hasCoalescedPartition` is deliberately not asserted: 0-size partitions are ignored.
  val numSkewedPartitions = reader.partitionSpecs.collect {
    case p: PartialReducerPartitionSpec => p.reducerIndex
  }.distinct.length
  assert(numSkewedPartitions == expectedSkewPartitions)
}
3061+
3062+
/**
 * Unwraps any number of `CommandResultExec` nodes and, if an `AdaptiveSparkPlanExec` is found
 * underneath, returns its final physical plan; any other plan is returned as is.
 */
protected def getCorePlan(plan: SparkPlan): SparkPlan = plan match {
  case org.apache.spark.sql.execution.CommandResultExec(_, wrapped, _) =>
    getCorePlan(wrapped)
  case adaptive: AdaptiveSparkPlanExec =>
    adaptive.finalPhysicalPlan
  case other =>
    other
}
3070+
3071+
/**
 * Returns the plan wrapped by a top-level `CommandResultExec`, or `plan` itself when it is not a
 * `CommandResultExec`. Only one level is removed.
 */
protected def stripCommandResultExec(plan: SparkPlan): SparkPlan = plan match {
  case org.apache.spark.sql.execution.CommandResultExec(_, wrapped, _) => wrapped
  case other => other
}
3077+
3078+
test("adaptive skewed insert: create as select command") {
  withTable("tbl", "tbl2") {
    withSQLConf(
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
      SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true",
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
      SQLConf.AUTO_REPARTITION_BEFORE_WRITING_ENABLED.key -> "true",
      SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key -> "100",
      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100") {

      // Every row gets key 0, so all data lands in a single shuffle partition.
      spark
        .range(0, 1000, 1, 10)
        .selectExpr("id % 1 as key", "id as value")
        .write.saveAsTable("tbl")

      // The listener only records the final adaptive plans. The assertions run on the test
      // thread below, because an assertion that fails inside a listener callback is reported
      // on the listener bus thread and does not fail the test.
      @volatile var finalPlans = Seq.empty[SparkPlan]
      val listener = new QueryExecutionListener {
        override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
          stripCommandResultExec(qe.executedPlan) match {
            case ae: AdaptiveSparkPlanExec =>
              finalPlans = finalPlans :+ ae.finalPhysicalPlan
            case _ =>
          }
        }
        override def onFailure(funcName: String, qe: QueryExecution,
            exception: Exception): Unit = {}
      }
      spark.listenerManager.register(listener)
      try {
        spark.sql("create table tbl2 using parquet " +
          "partitioned by (key) select * from tbl")
        // Listener callbacks are delivered asynchronously; wait for them before unregistering.
        spark.sparkContext.listenerBus.waitUntilEmpty()
      } finally {
        spark.listenerManager.unregister(listener)
      }

      assert(finalPlans.nonEmpty, "the listener did not observe any adaptive plan")
      finalPlans.foreach { finalPlan =>
        val queryStages = finalPlan.collect {
          case qs: ShuffleQueryStageExec => qs
        }
        assert(queryStages.length == 1)
        checkSkewInsert(finalPlan, 1)
      }
      assert(sql("select count(*) from tbl2").collect().head.getLong(0) == 1000)
    }
  }
}
3117+
3118+
test("adaptive skewed insert: insert into command") {
  withTable("tbl", "tbl2") {
    withSQLConf(
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
      SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true",
      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100",
      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100",
      SQLConf.PARTITION_OVERWRITE_MODE.key -> "dynamic",
      SQLConf.AUTO_REPARTITION_BEFORE_WRITING_ENABLED.key -> "true") {

      // Source table: every row carries the same key and value.
      val source = spark.range(0, 1000, 1, 10).selectExpr("id % 1 as key", "id % 1 as value")
      source.write.saveAsTable("tbl")
      spark.sql("create table tbl2(key int, value int) using parquet " +
        "partitioned by (key)")

      val insertDf = spark.sql("insert overwrite table tbl2 partition(key) select * from tbl")
      val corePlan = getCorePlan(insertDf.queryExecution.sparkPlan)

      // Exactly one write command, fed by exactly one shuffle stage with one split reducer.
      val writers = corePlan.collect { case w: DataWritingCommandExec => w }
      assert(writers.size == 1)
      val shuffleStages = corePlan.collect { case qs: ShuffleQueryStageExec => qs }
      assert(shuffleStages.length == 1)
      checkSkewInsert(corePlan, 1)

      assert(sql("select count(*) from tbl2").collect().head.getLong(0) == 1000)
    }
  }
}
3151+
3152+
test("CARMEL-2389 adaptive skewed insert: ArrayIndexOutOfBoundsException exception") {
  withTable("tbl", "tbl2") {
    withSQLConf(
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
      SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true",
      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100",
      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100",
      SQLConf.PARTITION_OVERWRITE_MODE.key -> "dynamic",
      SQLConf.AUTO_REPARTITION_BEFORE_WRITING_ENABLED.key -> "true",
      SQLConf.SHUFFLE_PARTITIONS.key -> "8") {

      // Source table with three distinct keys, shuffled into 8 partitions.
      val source = spark.range(0, 1000, 1, 10).selectExpr("id % 3 as key", "id % 1 as value")
      source.write.saveAsTable("tbl")
      spark.sql("create table tbl2(key int, value int) using parquet " +
        "partitioned by (key)")

      // The insert itself must complete without throwing.
      val insertDf = spark.sql("insert overwrite table tbl2 partition(key) select * from tbl")
      val corePlan = getCorePlan(insertDf.queryExecution.sparkPlan)

      val writers = corePlan.collect { case w: DataWritingCommandExec => w }
      assert(writers.size == 1)
      val shuffleStages = corePlan.collect { case qs: ShuffleQueryStageExec => qs }
      assert(shuffleStages.length == 1)
      assert(sql("select count(*) from tbl2").collect().head.getLong(0) == 1000)
    }
  }
}
3184+
3185+
test("adaptive skewed insert: insert into command, source table is empty") {
  withTable("tbl", "tbl2") {
    withSQLConf(
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
      SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true",
      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100",
      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100",
      SQLConf.PARTITION_OVERWRITE_MODE.key -> "dynamic",
      SQLConf.AUTO_REPARTITION_BEFORE_WRITING_ENABLED.key -> "true") {

      // Both tables are created empty; no rows are ever written to the source.
      spark.sql("create table tbl(key int, value int) using parquet " +
        "partitioned by (key)")
      spark.sql("create table tbl2(key int, value int) using parquet " +
        "partitioned by (key)")

      val insertDf = spark.sql("insert overwrite table tbl2 partition(key) select * from tbl")
      val corePlan = getCorePlan(insertDf.queryExecution.sparkPlan)

      val writers = corePlan.collect { case w: DataWritingCommandExec => w }
      assert(writers.size == 1)

      // The final plan is expected to contain no shuffle query stage at all.
      val shuffleStages = corePlan.collect { case qs: ShuffleQueryStageExec => qs }
      assert(shuffleStages.isEmpty)
      assert(sql("select count(*) from tbl2").collect().head.getLong(0) == 0)
    }
  }
}
30493215
}
30503216

30513217
/**

0 commit comments

Comments
 (0)