
Commit 5035f66

xingchaozh authored and GitHub Enterprise committed
[CARMEL-5912] Support PARALLEL Hint against Delta Table (#935)
1 parent 0dd2021 commit 5035f66

File tree

7 files changed: +150 −19 lines changed
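Before the per-file diffs, the user-facing shape of the feature, reconstructed from the new test at the bottom of this commit. A hedged sketch, not verbatim from the patch: it assumes a `SparkSession` (`spark`) built from this fork with the parallel-hint flag (`SQLConf.PARALLEL_HINT_ENABLED`) turned on, and the `bucket_table` fixture from the test.

```scala
// Sketch only: `spark` is assumed to come from this fork with the PARALLEL
// hint feature enabled. The hint asks the scan of `bucket_table` to target
// 100 partitions, which also disables bucketed scan for the Delta CTAS.
spark.sql(
  """CREATE TABLE a1 USING delta AS
    |SELECT /*+ PARALLEL(bucket_table, 100) */ ip_add_int, count(*) AS cnt
    |FROM bucket_table
    |GROUP BY ip_add_int
    |""".stripMargin)
```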

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningContext.scala

Lines changed: 23 additions & 5 deletions
@@ -29,6 +29,12 @@ class QueryPlanningContext {
     new java.util.concurrent.ConcurrentHashMap[String, TableParallelInfo]

   def clear(): Unit = tableParallelHintMap.clear()
+
+  def merge(context: Option[QueryPlanningContext]): Unit = {
+    context.foreach(p => {
+      tableParallelHintMap.putAll(p.tableParallelHintMap)
+    })
+  }
 }

 object QueryPlanningContext {
@@ -38,10 +44,13 @@ object QueryPlanningContext {
     .recordStats()
     .build()

-  def putIfAbsent(executionId: String, context: QueryPlanningContext): Unit = {
-    if (StringUtils.isNotBlank(executionId) &&
-      planningContextCache.getIfPresent(executionId) == null) {
-      planningContextCache.put(executionId, context)
+  def putIfAbsentOrMerge(executionId: String, context: QueryPlanningContext): Unit = {
+    if (StringUtils.isNotBlank(executionId)) {
+      if (getIfPresent(executionId).isEmpty) {
+        planningContextCache.put(executionId, context)
+      } else {
+        merge(executionId, Some(context))
+      }
     }
   }

@@ -52,9 +61,18 @@
   }

   def invalidate(executionId: String): Unit = {
-    getIfPresent(executionId).foreach(_.clear())
     if (StringUtils.isNotBlank(executionId)) {
+      getIfPresent(executionId).foreach(_.clear())
       planningContextCache.invalidate(executionId)
     }
   }
+
+  def merge(executionId: String, context: Option[QueryPlanningContext]): Unit = {
+    if (StringUtils.isNotBlank(executionId)) {
+      if (getIfPresent(executionId).isEmpty) {
+        planningContextCache.put(executionId, new QueryPlanningContext)
+      }
+      getIfPresent(executionId).foreach(_.merge(context))
+    }
+  }
 }
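A standalone sketch of the merge semantics introduced here, with the Guava cache and `TableParallelInfo` stubbed out as plain maps and strings (all names illustrative, not from the patch): a second registration under the same execution id now unions hint maps instead of being dropped, which is what the rename from `putIfAbsent` to `putIfAbsentOrMerge` signals.

```scala
import java.util.concurrent.ConcurrentHashMap

object MergeSemanticsSketch {
  // Stand-in: execution id -> (table identifier -> parallel hint info).
  private val cache = new ConcurrentHashMap[String, ConcurrentHashMap[String, String]]()

  def putIfAbsentOrMerge(executionId: String, hints: ConcurrentHashMap[String, String]): Unit = {
    if (executionId != null && executionId.nonEmpty) {
      val existing = cache.putIfAbsent(executionId, hints)
      // The old putIfAbsent behavior silently dropped `hints` here; the new
      // behavior unions them into the already-registered context.
      if (existing != null) existing.putAll(hints)
    }
  }

  def main(args: Array[String]): Unit = {
    val first = new ConcurrentHashMap[String, String]()
    first.put("db.t1", "partitionNumber=100")
    putIfAbsentOrMerge("42", first)

    val second = new ConcurrentHashMap[String, String]()
    second.put("db.t2", "partitionSize=128m")
    putIfAbsentOrMerge("42", second)

    println(cache.get("42")) // both db.t1 and db.t2 survive
  }
}
```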

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ object QueryPlanningTracker {
   }

   /** Returns the current tracker in scope, based on the thread local variable. */
-  def get: Option[QueryPlanningTracker] = Option(localTracker.get())
+  def getCurrent: Option[QueryPlanningTracker] = Option(localTracker.get())

   /** Sets the current tracker for the execution of function f. We assume f is single-threaded. */
   def withTracker[T](tracker: QueryPlanningTracker)(f: => T): T = {
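The rename itself is mechanical; for context, a sketch of how the renamed accessor pairs with `withTracker`, inferred from the two signatures quoted above (usage assumed, not shown in this commit):

```scala
// Assumed usage: withTracker installs the tracker in the thread local for the
// duration of f, and getCurrent reads whatever is currently installed.
val tracker = new QueryPlanningTracker
QueryPlanningTracker.withTracker(tracker) {
  assert(QueryPlanningTracker.getCurrent.contains(tracker))
}
// Outside the scope no tracker is installed on this thread.
assert(QueryPlanningTracker.getCurrent.isEmpty)
```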

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
     var curPlan = plan
     val queryExecutionMetrics = RuleExecutor.queryExecutionMeter
     val planChangeLogger = new PlanChangeLogger()
-    val tracker: Option[QueryPlanningTracker] = QueryPlanningTracker.get
+    val tracker: Option[QueryPlanningTracker] = QueryPlanningTracker.getCurrent
     val beforeMetrics = RuleExecutor.getCurrentMetrics()

     // Run the structural integrity checker against the initial input

sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala

Lines changed: 15 additions & 1 deletion
@@ -24,6 +24,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.config.Tests.IS_TESTING
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.QueryPlanningContext
 import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
 import org.apache.spark.sql.internal.StaticSQLConf.SQL_EVENT_TRUNCATE_LENGTH
 import org.apache.spark.util.Utils
@@ -62,15 +63,28 @@ object SQLExecution {
    * Wrap an action that will execute "queryExecution" to track all Spark jobs in the body so that
    * we can connect them with an execution.
    */
+  def withNewExecutionId[T](
+      queryExecution: QueryExecution,
+      name: Option[String] = None)(body: => T): T = {
+    withNewExecutionId(queryExecution, name, false)(body)
+  }
+
   def withNewExecutionId[T](
       queryExecution: QueryExecution,
-      name: Option[String] = None)(body: => T): T = queryExecution.sparkSession.withActive {
+      name: Option[String],
+      inheritParentQueryPlanningContext: Boolean)(body: => T): T =
+    queryExecution.sparkSession.withActive {
     val sparkSession = queryExecution.sparkSession
     val sc = sparkSession.sparkContext
     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
     val executionId = SQLExecution.nextExecutionId
     sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
     executionIdToQueryExecution.put(executionId, queryExecution)
+
+    if (inheritParentQueryPlanningContext) {
+      val parentQueryPlanningContext = QueryPlanningContext.getIfPresent(oldExecutionId)
+      QueryPlanningContext.merge(executionId.toString, parentQueryPlanningContext)
+    }
     try {
       // sparkContext.getCallSite() would first try to pick up any call site that was previously
       // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
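A self-contained sketch of the inheritance step added above (types stubbed, names mirror the patch): when the new flag is set, the parent's planning context, keyed by the old execution id, is merged into an entry for the new id before the body runs, so hints collected while planning the outer statement stay visible to the nested execution.

```scala
import java.util.concurrent.ConcurrentHashMap

object InheritContextSketch {
  type Context = ConcurrentHashMap[String, String] // stub for QueryPlanningContext
  private val cache = new ConcurrentHashMap[String, Context]()

  def merge(executionId: String, parent: Option[Context]): Unit = {
    // Create the child's entry if absent, then union the parent's hints in.
    val child = cache.computeIfAbsent(executionId, _ => new Context)
    parent.foreach(p => child.putAll(p))
  }

  def main(args: Array[String]): Unit = {
    val parentCtx = new Context
    parentCtx.put("db.bucket_table", "partitionNumber=100")
    cache.put("1", parentCtx) // "1" plays the role of oldExecutionId

    // The inheritParentQueryPlanningContext = true path for child execution "2":
    merge("2", Option(cache.get("1")))
    println(cache.get("2")) // {db.bucket_table=partitionNumber=100}
  }
}
```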

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CacheParallelHint.scala

Lines changed: 14 additions & 10 deletions
@@ -51,20 +51,24 @@ private[sql] object CacheParallelHint
         if (f.partitionSize.nonEmpty || f.partitionNumber.nonEmpty) => f
     }.foreach(r => {
       logicalRelation.catalogTable.foreach(catalogTable => {
-        val tracker: Option[QueryPlanningTracker] = QueryPlanningTracker.get
-        tracker.foreach(f => {
-          SparkContext.getActive.foreach(sc => {
-            val tableParallelInfo =
-              TableParallelInfo(catalogTable.identifier.toString, r.partitionSize,
+        SparkContext.getActive.foreach(sc => {
+          val tableParallelInfo =
+            TableParallelInfo(catalogTable.identifier.toString, r.partitionSize,
               r.partitionNumber)

-            f.queryPlanningContext.tableParallelHintMap.
-              put(catalogTable.identifier.toString, tableParallelInfo)
+          val queryPlanningContext = if (QueryPlanningTracker.getCurrent.isEmpty) {
+            new QueryPlanningContext
+          } else {
+            QueryPlanningTracker.getCurrent.get.queryPlanningContext
+          }

-            val executionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-            QueryPlanningContext.putIfAbsent(executionId, f.queryPlanningContext)
-          })
+          queryPlanningContext.tableParallelHintMap.
+            put(catalogTable.identifier.toString, tableParallelInfo)
+
+          val executionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+          QueryPlanningContext.putIfAbsentOrMerge(executionId, queryPlanningContext)
         })
+
       })
     })

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@ object FileSourceStrategy extends Strategy with SQLConfHelper with Logging {
     })

     if (tableParallelInfo.isEmpty) {
-      QueryPlanningTracker.get.foreach(f => {
+      QueryPlanningTracker.getCurrent.foreach(f => {
         tableParallelInfo = Option(
           f.queryPlanningContext.tableParallelHintMap.get(tableIdentifier)
         )
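For clarity, the lookup order this hunk sits in: the scan's own `tableParallelInfo` wins, and the tracker-scoped hint map is consulted only as a fallback. A minimal stand-alone illustration with stub types (not the fork's classes):

```scala
// Stubbed illustration of the fallback: per-relation info first, then the
// thread-local tracker's hint map, keyed by table identifier.
var tableParallelInfo: Option[String] = None
val trackerHintMap = Map("db.bucket_table" -> "partitionNumber=100")
if (tableParallelInfo.isEmpty) {
  tableParallelInfo = trackerHintMap.get("db.bucket_table")
}
println(tableParallelInfo) // Some(partitionNumber=100)
```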

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/delta/DeltaQuerySuite.scala

Lines changed: 95 additions & 0 deletions
@@ -3164,6 +3164,101 @@ class DeltaQuerySuite extends QueryTest
     }
   }

+  test("Bucket scan in delta CTAS disabled if PARALLEL HINT detected.") {
+    checkBucketingResultInDeltaCTASIfParallelDetected("/*+ PARALLEL(bucket_table, 100) */",
+      targetPartitionNumber = 100, parallelHintEnabled = true,
+      viewName = "bucket_table_v_with_hint")
+  }
+
+  test("Bucket scan in delta CTAS enabled if PARALLEL HINT not detected.") {
+    checkBucketingResultInDeltaCTASIfParallelDetected("",
+      targetPartitionNumber = 100, parallelHintEnabled = true,
+      viewName = "bucket_table_v_without_hint")
+  }
+
+  def checkBucketingResultInDeltaCTASIfParallelDetected(hint: String,
+      targetPartitionNumber: Long = 1024,
+      parallelHintEnabled: Boolean = true,
+      viewName: String): Unit = {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.RANGE_JOIN_ENABLED.key -> "false", // Disable range join for testing
+      SQLConf.PARALLEL_HINT_ENABLED.key -> parallelHintEnabled.toString,
+      SQLConf.SQL_UI_PLAN_WITH_METRICS.key -> "false") {
+      withTable("bucket_table", "a1") {
+
+        withView(viewName) {
+          sql(
+            s"""
+               |CREATE TABLE bucket_table(id string, ip_add_int long, dt string) USING PARQUET
+               |PARTITIONED BY (dt)
+               |CLUSTERED BY (ip_add_int)
+               |INTO 10 BUCKETS
+               |
+               |""".stripMargin)
+
+          sql(
+            s"""
+               |insert into bucket_table PARTITION (dt = "2021-12-20")
+               |select 123 as id, 456 as ip_add_int
+             """.stripMargin
+          )
+
+          sql(
+            s"""
+               |create view ${viewName}(id, ip_add_int, dt) as
+               |select id, ip_add_int, dt from bucket_table
+             """.stripMargin
+          )
+
+          val df = sql(
+            s"""
+               |create temp table a1 using delta AS
+               |SELECT ${hint} ip_add_int, count(*) as cnt FROM (
+               |SELECT ip_add_int, count(*) as c FROM (
+               |SELECT ip_add_int, count(*) FROM ${viewName} ev3
+               |WHERE ev3.id !=0 AND ev3.dt>="2021-12-20" group by 1
+               |) ev2 group by 1
+               |) ev group by 1
+             """.stripMargin
+          )
+
+          sql(
+            s"""
+               |select * from a1
+             """.stripMargin
+          ).show()
+
+          // scalastyle:off println
+          println(s"logical plan: ${df.queryExecution.logical}")
+          println(s"optimized plan: ${df.queryExecution.optimizedPlan}")
+          println(df.collect().length)
+          // scalastyle:on println
+
+          // checkRangeJoinResultInner(df, buildSide, false, 0)
+          val planAfter = df.queryExecution.executedPlan
+          // scalastyle:off println
+          println(s"planAfter: ${planAfter}")
+          // scalastyle:on println
+          val scan = collectFirst(planAfter) {
+            case f: FileSourceScanExec => f
+          }
+          assert(scan.isDefined)
+          assert(scan.get.dataFilters.nonEmpty)
+
+          if (hint.contains("PARALLEL") && parallelHintEnabled) {
+            assert(scan.get.tableParallelInfo.isDefined)
+            assert(scan.get.tableParallelInfo.get.partitionNumber.get == targetPartitionNumber)
+            assert(scan.get.bucketedScan == false)
+          } else {
+            assert(scan.get.tableParallelInfo.isEmpty)
+            assert(scan.get.bucketedScan == true)
+          }
+        }
+      }
+    }
+  }
+
   test("test analyze delta table") {
     def getCatalogStatistics(tableName: String): CatalogStatistics = {
       getCatalogTable(TableIdentifier(tableName)).stats.get
