Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ trait CatalystConf {
def groupByOrdinal: Boolean

def optimizerMaxIterations: Int
def maxCaseBranchesForCodegen: Int

/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
Expand All @@ -45,6 +46,7 @@ case class SimpleCatalystConf(
caseSensitiveAnalysis: Boolean,
orderByOrdinal: Boolean = true,
groupByOrdinal: Boolean = true,
optimizerMaxIterations: Int = 100)
optimizerMaxIterations: Int = 100,
maxCaseBranchesForCodegen: Int = 20)
extends CatalystConf {
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,15 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
}

/**
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
* When a = true, returns b; when c = true, returns d; else returns e.
* Abstract parent class for common logic in CaseWhen and CaseWhenCodegen.
*
* @param branches seq of (branch condition, branch value)
* @param elseValue optional value for the else branch
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.")
// scalastyle:on line.size.limit
case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None)
extends Expression with CodegenFallback {
abstract class CaseWhenBase(
branches: Seq[(Expression, Expression)],
elseValue: Option[Expression])
extends Expression with Serializable {

override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue

Expand Down Expand Up @@ -142,16 +139,58 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
}
}

def shouldCodegen: Boolean = {
branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN
/** Renders the expression as "CASE WHEN c THEN v ... [ELSE e] END" using each child's toString. */
override def toString: String = {
  val whenClauses = branches.map { case (cond, value) =>
    s" WHEN $cond THEN $value"
  }
  // An absent else branch contributes nothing to the rendered text.
  val elseClause = elseValue.fold("")(e => " ELSE " + e)
  s"CASE${whenClauses.mkString}$elseClause END"
}

/** Renders the expression as "CASE WHEN c THEN v ... [ELSE e] END" using each child's `sql`. */
override def sql: String = {
  val whenClauses = branches.map { case (cond, value) =>
    s" WHEN ${cond.sql} THEN ${value.sql}"
  }
  // An absent else branch contributes nothing to the rendered text.
  val elseClause = elseValue.fold("")(e => " ELSE " + e.sql)
  s"CASE${whenClauses.mkString}$elseClause END"
}
}


/**
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
* When a = true, returns b; when c = true, returns d; else returns e.
*
* @param branches seq of (branch condition, branch value)
* @param elseValue optional value for the else branch
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.")
// scalastyle:on line.size.limit
case class CaseWhen(
val branches: Seq[(Expression, Expression)],
val elseValue: Option[Expression] = None)
extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe just have a toCodegen function that creates CaseWhenCodegen?

We can then remove object CaseWhenCodegen

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be right. CaseWhenCodegen is always generated from CaseWhen.


/** Code generation for a plain CaseWhen always delegates to the [[CodegenFallback]] mixin. */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
  super[CodegenFallback].doGenCode(ctx, ev)

/**
 * Builds the [[CaseWhenCodegen]] counterpart of this expression, carrying over the same
 * branches and else value.
 */
def toCodegen(): CaseWhenCodegen = CaseWhenCodegen(branches, elseValue)
}

/**
* CaseWhen expression used when code generation condition is satisfied.
* The OptimizeCodegen optimizer rule replaces CaseWhen with CaseWhenCodegen.
*
* @param branches seq of (branch condition, branch value)
* @param elseValue optional value for the else branch
*/
case class CaseWhenCodegen(
val branches: Seq[(Expression, Expression)],
val elseValue: Option[Expression] = None)
extends CaseWhenBase(branches, elseValue) with Serializable {

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
if (!shouldCodegen) {
// Fallback to interpreted mode if there are too many branches, as it may reach the
// 64K limit (limit on bytecode size for a single function).
return super[CodegenFallback].doGenCode(ctx, ev)
}
// Generate code that looks like:
//
// condA = ...
Expand Down Expand Up @@ -202,26 +241,10 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$generatedCode""")
}

override def toString: String = {
val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
"CASE" + cases + elseCase + " END"
}

override def sql: String = {
val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
"CASE" + cases + elseCase + " END"
}
}

/** Factory methods for CaseWhen. */
object CaseWhen {

// The maximum number of switches supported with codegen.
val MAX_NUM_CASES_FOR_CODEGEN = 20

/**
 * Convenience constructor taking the else value directly; it is wrapped with `Option(...)`,
 * so a null `elseValue` yields a CaseWhen without an else branch.
 */
def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen =
  CaseWhen(branches, Option(elseValue))
Expand All @@ -242,7 +265,6 @@ object CaseWhen {
}
}


/**
* Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
* When a = b, returns c; when a = d, returns e; else returns f.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation) ::
Batch("Subquery", Once,
OptimizeSubqueries) :: Nil
OptimizeSubqueries) ::
Batch("OptimizeCodegen", Once,
OptimizeCodegen(conf)) :: Nil
}

/**
Expand Down Expand Up @@ -863,6 +865,16 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
}
}

/**
 * Rewrites every [[CaseWhen]] whose number of branches is strictly below
 * `conf.maxCaseBranchesForCodegen` into its [[CaseWhenCodegen]] form via `toCodegen()`.
 * Expressions that do not match are not rewritten by this rule.
 */
case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
    case caseWhen: CaseWhen if caseWhen.branches.size < conf.maxCaseBranchesForCodegen =>
      caseWhen.toCodegen()
  }
}

/**
* Combines all adjacent [[Union]] operators into a single [[Union]].
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


class OptimizeCodegenSuite extends PlanTest {

/** Rule executor that applies only OptimizeCodegen, once, with a default SimpleCatalystConf. */
object Optimize extends RuleExecutor[LogicalPlan] {
  private val codegenConf = SimpleCatalystConf(caseSensitiveAnalysis = true)
  val batches = List(Batch("OptimizeCodegen", Once, OptimizeCodegen(codegenConf)))
}

/**
 * Asserts that optimizing a one-column projection of `e1` yields the same plan as the
 * analyzed, un-optimized projection of `e2`.
 *
 * @param e1 expression fed through the Optimize executor
 * @param e2 expression the optimized result is expected to match
 */
protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
  // Both sides are wrapped identically so that only the expression itself can differ.
  def projectionOf(expr: Expression): LogicalPlan =
    Project(Alias(expr, "out")() :: Nil, OneRowRelation).analyze

  val expected = projectionOf(e2)
  val optimized = Optimize.execute(projectionOf(e1))
  comparePlans(optimized, expected)
}

test("Codegen only when the number of branches is small.") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make sure you construct a few more test cases?

One with a nested CaseWhen, one with multiple CaseWhen expressions in one operator, and one with multiple CaseWhen expressions in different operators.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh. Sure. I'll add those test cases, too.

assertEquivalent(
CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen())

assertEquivalent(
CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2)),
CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2)))
}

// The CaseWhen used as a condition, the one used as the else value, and the enclosing
// CaseWhen each have a single branch, so every level is expected to be converted.
test("Nested CaseWhen Codegen.") {
  val innerCondition = CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2))
  val innerElse = CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))

  val input = CaseWhen(Seq((innerCondition, Literal(3))), innerElse)
  val expected =
    CaseWhen(Seq((innerCondition.toCodegen(), Literal(3))), innerElse.toCodegen()).toCodegen()

  assertEquivalent(input, expected)
}

// Several CaseWhen expressions in a single projection: the small ones are expected to be
// converted, the 20-branch one is expected to stay as is.
test("Multiple CaseWhen in one operator.") {
  val first = CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2))
  val second = CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4))
  // 20 branches is not strictly below the default limit of 20, so this one is left alone.
  val atLimit = CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))
  val last = CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6))

  val input = OneRowRelation.select(first, second, atLimit, last).analyze
  val expected = OneRowRelation
    .select(first.toCodegen(), second.toCodegen(), atLimit, last.toCodegen())
    .analyze

  comparePlans(Optimize.execute(input), expected)
}

// CaseWhen expressions spread over a projection and a filter: conversion is expected in both
// operators, except for the 20-branch expression.
test("Multiple CaseWhen in different operators") {
  val projected1 = CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2))
  val projected2 = CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4))
  // 20 branches is not strictly below the default limit of 20, so this one is left alone.
  val atLimit = CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))
  val filtered = CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6))

  val input = OneRowRelation
    .select(projected1, projected2, atLimit)
    .where(LessThan(filtered, atLimit))
    .analyze
  val expected = OneRowRelation
    .select(projected1.toCodegen(), projected2.toCodegen(), atLimit)
    .where(LessThan(filtered.toCodegen(), atLimit))
    .analyze

  comparePlans(Optimize.execute(input), expected)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,6 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {

private def supportCodegen(e: Expression): Boolean = e match {
case e: LeafExpression => true
case e: CaseWhen => e.shouldCodegen
// CodegenFallback requires the input to be an InternalRow
case e: CodegenFallback => false
case _ => true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,12 @@ object SQLConf {
.intConf
.createWithDefault(200)

// Internal integer setting (default 20), exposed through SQLConf.maxCaseBranchesForCodegen.
// NOTE(review): OptimizeCodegen compares the branch count with `<`, so a CaseWhen with exactly
// this many branches is not converted -- confirm that "maximum" is meant to be exclusive.
val MAX_CASES_BRANCHES = SQLConfigBuilder("spark.sql.codegen.maxCaseBranches")
.internal()
.doc("The maximum number of switches supported with codegen.")
.intConf
.createWithDefault(20)

val FILES_MAX_PARTITION_BYTES = SQLConfigBuilder("spark.sql.files.maxPartitionBytes")
.doc("The maximum number of bytes to pack into a single partition when reading files.")
.longConf
Expand Down Expand Up @@ -529,6 +535,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {

def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS)

// Reads the MAX_CASES_BRANCHES setting (spark.sql.codegen.maxCaseBranches).
def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES)

def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED)

def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW)
Expand Down