Commit 51ef443

maropu authored and dongjoon-hyun committed
[SPARK-33822][SQL] Use the CastSupport.cast method in HashJoin
### What changes were proposed in this pull request?

This PR intends to fix a bug that throws an `UnsupportedOperationException` when running [the TPCDS q5](https://github.com/apache/spark/blob/master/sql/core/src/test/resources/tpcds/q5.sql) with AQE enabled ([this option is enabled by default now via SPARK-33679](031c5ef)):

```
java.lang.UnsupportedOperationException: BroadcastExchange does not support the execute() code path.
  at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecute(BroadcastExchangeExec.scala:189)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
  at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
  at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
  at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:176)
  at org.apache.spark.sql.execution.exchange.ReusedExchangeExec.doExecute(Exchange.scala:60)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
  at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
  at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
  at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:176)
  at org.apache.spark.sql.execution.adaptive.QueryStageExec.doExecute(QueryStageExec.scala:115)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
  at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
  at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
  at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:176)
  at org.apache.spark.sql.execution.SparkPlan.getByteArrayRdd(SparkPlan.scala:321)
  at org.apache.spark.sql.execution.SparkPlan.executeCollectIterator(SparkPlan.scala:397)
  at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.$anonfun$relationFuture$1(BroadcastExchangeExec.scala:118)
  at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$1(SQLExecution.scala:185)
  at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
  ...
```

I've checked the AQE code and found that `EnsureRequirements` wrongly puts a `BroadcastExchange` on top of a `BroadcastQueryStage` in the `reOptimize` phase, as follows:

```
+- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#2183]
   +- BroadcastQueryStage 2
      +- ReusedExchange [d_date_sk#1086], BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#1963]
```

The root cause is that the `Cast` in the required child distribution has no `timeZoneId` (`timeZoneId=None`), while the `Cast` in `child.outputPartitioning` does. This difference makes the distribution requirement check in `EnsureRequirements` fail:
https://github.com/apache/spark/blob/1e85707738a830d33598ca267a6740b3f06b1861/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala#L47-L50

The `Cast` without a `timeZoneId` is generated in the `HashJoin` object, so this PR proposes to use the `CastSupport.cast` method there.

### Why are the changes needed?

Bugfix.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Manually checked that q5 passed.

Closes #30818 from maropu/BugfixInAQE.

Authored-by: Takeshi Yamamuro <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
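
To make the root cause concrete, here is a minimal, hypothetical Scala sketch (not part of the PR; the attribute and time zone are made up for illustration) of why two otherwise identical `Cast` expressions stop matching once only one of them carries a `timeZoneId`:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast}
import org.apache.spark.sql.types.{IntegerType, LongType}

// Hypothetical join key; the name mirrors the plan above but is purely illustrative.
val key = AttributeReference("d_date_sk", IntegerType)()

// What HashJoin built before this PR: a Cast with timeZoneId = None.
val castWithoutTz = Cast(key, LongType)

// What appears in child.outputPartitioning: a Cast carrying the session time zone.
val castWithTz = Cast(key, LongType, Some("America/Los_Angeles"))

// timeZoneId is a constructor field of the Cast case class, so the two
// expressions compare unequal, which is what trips the requirement check.
assert(castWithoutTz != castWithTz)
```
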
1 parent 15616f4 commit 51ef443

File tree

2 files changed: +27 -19 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala

Lines changed: 7 additions & 6 deletions
@@ -17,7 +17,8 @@
 package org.apache.spark.sql.execution.joins

-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
+import org.apache.spark.sql.catalyst.analysis.CastSupport
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -756,7 +757,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
   protected def prepareRelation(ctx: CodegenContext): HashedRelationInfo
 }

-object HashJoin {
+object HashJoin extends CastSupport with SQLConfHelper {
   /**
    * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long.
    *
@@ -771,14 +772,14 @@ object HashJoin {
     }

     var keyExpr: Expression = if (keys.head.dataType != LongType) {
-      Cast(keys.head, LongType)
+      cast(keys.head, LongType)
     } else {
       keys.head
     }
     keys.tail.foreach { e =>
       val bits = e.dataType.defaultSize * 8
       keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
-        BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+        BitwiseAnd(cast(e, LongType), Literal((1L << bits) - 1)))
     }
     keyExpr :: Nil
   }
@@ -791,13 +792,13 @@ object HashJoin {
     // jump over keys that have a higher index value than the required key
     if (keys.size == 1) {
       assert(index == 0)
-      Cast(BoundReference(0, LongType, nullable = false), keys(index).dataType)
+      cast(BoundReference(0, LongType, nullable = false), keys(index).dataType)
     } else {
       val shiftedBits =
         keys.slice(index + 1, keys.size).map(_.dataType.defaultSize * 8).sum
       val mask = (1L << (keys(index).dataType.defaultSize * 8)) - 1
       // build the schema for unpacking the required key
-      Cast(BitwiseAnd(
+      cast(BitwiseAnd(
         ShiftRightUnsigned(BoundReference(0, LongType, nullable = false), Literal(shiftedBits)),
         Literal(mask)), keys(index).dataType)
     }
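
For context, the effect of calling `cast(...)` instead of constructing `Cast(...)` directly can be pictured with the rough sketch below. This is a simplified stand-in, not the actual `CastSupport`/`SQLConfHelper` definitions in Spark: it only shows that the helper attaches the session-local time zone, so the generated expression matches the `Cast`s found in `child.outputPartitioning`.

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType

// Simplified stand-in for what the traits mixed into object HashJoin provide:
// cast(child, dt) builds a Cast that carries the session-local time zone,
// whereas a bare Cast(child, dt) leaves timeZoneId = None.
trait CastSupportSketch {
  def conf: SQLConf = SQLConf.get

  def cast(child: Expression, dataType: DataType): Cast =
    Cast(child, dataType, Some(conf.sessionLocalTimeZone))
}
```
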

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala

Lines changed: 20 additions & 13 deletions
@@ -242,33 +242,40 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
     assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil)
     assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil)

-    assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(i :: Nil) ===
+      Cast(i, LongType, Some(conf.sessionLocalTimeZone)) :: Nil)
     assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil)
     assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) ===
-      BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)),
-        BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))) :: Nil)
+      BitwiseOr(ShiftLeft(Cast(i, LongType, Some(conf.sessionLocalTimeZone)), Literal(32)),
+        BitwiseAnd(Cast(i, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 32) - 1))) ::
+        Nil)
     assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil)

-    assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: Nil) ===
+      Cast(s, LongType, Some(conf.sessionLocalTimeZone)) :: Nil)
     assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil)
     assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) ===
-      BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
-        BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+      BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)),
+        BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) ::
+        Nil)
     assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) ===
       BitwiseOr(ShiftLeft(
-        BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
-          BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+        BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)),
+          BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))),
         Literal(16)),
-        BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+        BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) ::
+        Nil)
     assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) ===
       BitwiseOr(ShiftLeft(
         BitwiseOr(ShiftLeft(
-          BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
-            BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+          BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)),
+            BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)),
+              Literal((1L << 16) - 1))),
           Literal(16)),
-          BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+          BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))),
         Literal(16)),
-        BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+        BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) ::
+        Nil)
     assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) ===
       s :: s :: s :: s :: s :: Nil)
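
As a side note on what these assertions encode, `rewriteKeyExpr` packs several narrow keys into one long via shifts and masks, and `extractKeyExprAt` reverses that. A tiny, self-contained illustration of the packing scheme with plain Scala longs (the values are made up, not taken from the PR):

```scala
// Pack two 32-bit ints into one long: high half is the first key, low half the second.
val a = 7
val b = -1
val mask32 = (1L << 32) - 1
val packed = (a.toLong << 32) | (b.toLong & mask32)

// Unpacking mirrors extractKeyExprAt: shift right for the high key, mask for the low one.
assert((packed >>> 32).toInt == a)
assert((packed & mask32).toInt == b)
```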
