diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index a276e477c83c4..09bc6412a7f06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.execution -import org.apache.spark.metrics.source.CodegenMetrics -import org.apache.spark.sql.{QueryTest, Row, SaveMode} +import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed @@ -145,10 +143,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { .select("int") val plan = df.queryExecution.executedPlan - assert(!plan.find(p => + assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.children(0) - .isInstanceOf[SortMergeJoinExec]).isDefined) + .isInstanceOf[SortMergeJoinExec]).isEmpty) assert(df.collect() === Array(Row(1), Row(2))) } } @@ -181,6 +179,13 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 } + def genCode(ds: Dataset[_]): Seq[CodeAndComment] = { + val plan = ds.queryExecution.executedPlan + val wholeStageCodeGenExecs = plan.collect { case p: WholeStageCodegenExec => p } + assert(wholeStageCodeGenExecs.nonEmpty, "WholeStageCodegenExec is expected") + wholeStageCodeGenExecs.map(_.doCodeGen()._2) + } + ignore("SPARK-21871 check if we can get large code size when compiling too long functions") { val codeWithShortFunctions = genGroupByCode(3) val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) @@ -241,9 +246,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val df = spark.range(100) val join = df.join(df, "id") val plan = join.queryExecution.executedPlan - assert(!plan.find(p => + assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isDefined, + p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isEmpty, "codegen stage IDs should be preserved through ReuseExchange") checkAnswer(join, df.toDF) } @@ -253,18 +258,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { import testImplicits._ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") { - val bytecodeSizeHisto = CodegenMetrics.METRIC_COMPILATION_TIME - - // the same query run twice should hit the codegen cache - spark.range(3).select('id + 2).collect - val after1 = bytecodeSizeHisto.getCount - spark.range(3).select('id + 2).collect - val after2 = bytecodeSizeHisto.getCount // same query shape as above, deliberately - // bytecodeSizeHisto's count is always monotonically increasing if new compilation to - // bytecode had occurred. If the count stayed the same that means we've got a cache hit. - assert(after1 == after2, "Should hit codegen cache. No new compilation to bytecode expected") - - // a different query can result in codegen cache miss, that's by design + // the same query run twice should produce identical code + val ds1 = spark.range(3).select('id + 2) + val code1 = genCode(ds1) + val ds2 = spark.range(3).select('id + 2) + val code2 = genCode(ds2) // same query shape as above, deliberately + assert(code1 == code2, "Should produce same code") } }