Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.{InternalRow, trees}

/**
* A bound reference points to a specific slot in the input tuple, allowing the actual value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ abstract class Expression extends TreeNode[Expression] {
/** Returns the result of evaluating this expression on a given input Row */
def eval(input: InternalRow = null): Any

/**
* Return true if this expression is thread-safe, which means it could be used by multiple
* threads in the same time.
*
* An expression that is not thread-safe can not be cached and re-used, especially for codegen.
*/
def isThreadSafe: Boolean = true

/**
* Returns an [[GeneratedExpressionCode]], which contains Java source code that
* can be used to generate the result of evaluating the expression on an input row.
Expand All @@ -69,6 +77,9 @@ abstract class Expression extends TreeNode[Expression] {
* @return [[GeneratedExpressionCode]]
*/
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
if (!isThreadSafe) {
throw new Exception(s"$this is not thread-safe, can not be used in codegen")
}
val isNull = ctx.freshName("isNull")
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
Expand Down Expand Up @@ -170,6 +181,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express

override def toString: String = s"($left $symbol $right)"

override def isThreadSafe: Boolean = left.isThreadSafe && right.isThreadSafe
/**
* Short hand for generating binary evaluation code, which depends on two sub-evaluations of
* the same type. If either of the sub-expressions is null, the result of this computation
Expand Down Expand Up @@ -219,6 +231,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio

override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable
override def isThreadSafe: Boolean = child.isThreadSafe

/**
* Called by unary expressions to generate a code block that returns null if its parent returns
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -958,4 +958,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType)
override def eval(input: InternalRow): Any = converter(f(input))

// TODO(davies): make ScalaUdf work with codegen
override def isThreadSafe: Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.{InternalRow, trees}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types.DataType

abstract sealed class SortDirection
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,31 +348,29 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
if (ctx.isNativeType(left.dataType)) {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
eval1.code + eval2.code + s"""
boolean ${ev.isNull} = false;
${ctx.javaType(left.dataType)} ${ev.primitive} =
${ctx.defaultValue(left.dataType)};

if (${eval1.isNull}) {
${ev.isNull} = ${eval2.isNull};
${ev.primitive} = ${eval2.primitive};
} else if (${eval2.isNull}) {
${ev.isNull} = ${eval1.isNull};
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val compCode = ctx.genComp(dataType, eval1.primitive, eval2.primitive)

eval1.code + eval2.code + s"""
boolean ${ev.isNull} = false;
${ctx.javaType(left.dataType)} ${ev.primitive} =
${ctx.defaultValue(left.dataType)};

if (${eval1.isNull}) {
${ev.isNull} = ${eval2.isNull};
${ev.primitive} = ${eval2.primitive};
} else if (${eval2.isNull}) {
${ev.isNull} = ${eval1.isNull};
${ev.primitive} = ${eval1.primitive};
} else {
if ($compCode > 0) {
${ev.primitive} = ${eval1.primitive};
} else {
if (${eval1.primitive} > ${eval2.primitive}) {
${ev.primitive} = ${eval1.primitive};
} else {
${ev.primitive} = ${eval2.primitive};
}
${ev.primitive} = ${eval2.primitive};
}
"""
} else {
super.genCode(ctx, ev)
}
}
"""
}
override def toString: String = s"MaxOf($left, $right)"
}
Expand Down Expand Up @@ -402,33 +400,29 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
if (ctx.isNativeType(left.dataType)) {

val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)

eval1.code + eval2.code + s"""
boolean ${ev.isNull} = false;
${ctx.javaType(left.dataType)} ${ev.primitive} =
${ctx.defaultValue(left.dataType)};
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val compCode = ctx.genComp(dataType, eval1.primitive, eval2.primitive)

if (${eval1.isNull}) {
${ev.isNull} = ${eval2.isNull};
${ev.primitive} = ${eval2.primitive};
} else if (${eval2.isNull}) {
${ev.isNull} = ${eval1.isNull};
eval1.code + eval2.code + s"""
boolean ${ev.isNull} = false;
${ctx.javaType(left.dataType)} ${ev.primitive} =
${ctx.defaultValue(left.dataType)};

if (${eval1.isNull}) {
${ev.isNull} = ${eval2.isNull};
${ev.primitive} = ${eval2.primitive};
} else if (${eval2.isNull}) {
${ev.isNull} = ${eval1.isNull};
${ev.primitive} = ${eval1.primitive};
} else {
if ($compCode < 0) {
${ev.primitive} = ${eval1.primitive};
} else {
if (${eval1.primitive} < ${eval2.primitive}) {
${ev.primitive} = ${eval1.primitive};
} else {
${ev.primitive} = ${eval2.primitive};
}
${ev.primitive} = ${eval2.primitive};
}
"""
} else {
super.genCode(ctx, ev)
}
}
"""
}

override def toString: String = s"MinOf($left, $right)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import com.google.common.cache.{CacheBuilder, CacheLoader}
import org.codehaus.janino.ClassBodyEvaluator

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -176,9 +175,8 @@ class CodeGenContext {
* Generate code for compare expression in Java
*/
def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
// Use signum() to keep any small difference bwteen float/double
case FloatType | DoubleType => s"(int)java.lang.Math.signum($c1 - $c2)"
case dt: DataType if isPrimitiveType(dt) => s"(int)($c1 - $c2)"
// use c1 - c2 may overflow
case dt: DataType if isPrimitiveType(dt) => s"(int)($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
case other => s"$c1.compare($c2)"
}
Expand Down Expand Up @@ -266,7 +264,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
* weak keys/values and thus does not respond to memory pressure.
*/
protected val cache = CacheBuilder.newBuilder()
.maximumSize(1000)
.maximumSize(100)
.build(
new CacheLoader[InType, OutType]() {
override def load(in: InType): OutType = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ case class Alias(child: Expression, name: String)(

override def eval(input: InternalRow): Any = child.eval(input)

override def isThreadSafe: Boolean = child.isThreadSafe

override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)

override def dataType: DataType = child.dataType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types.DataType

case class Coalesce(children: Seq[Expression]) extends Expression {
Expand Down Expand Up @@ -53,6 +52,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
result
}

override def isThreadSafe: Boolean = children.forall(_.isThreadSafe)

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
s"""
boolean ${ev.isNull} = true;
Expand All @@ -73,7 +74,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
}

case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
override def foldable: Boolean = child.foldable
override def nullable: Boolean = false

Expand All @@ -91,7 +92,7 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr
override def toString: String = s"IS NULL $child"
}

case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
override def foldable: Boolean = child.foldable
override def nullable: Boolean = false
override def toString: String = s"IS NOT NULL $child"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{NumericType, DataType}
import org.apache.spark.sql.types.{DataType, NumericType}

/**
* The trait of the Window Specification (specified in the OVER clause or WINDOW clause) for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters, analysis}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis}
import org.apache.spark.sql.types.{StructField, StructType}

object LocalRelation {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import java.math.BigInteger
import java.sql.{Date, Timestamp}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

case class PrimitiveData(
Expand Down Expand Up @@ -75,7 +74,7 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) {
}

class ScalaReflectionSuite extends SparkFunSuite {
import ScalaReflection._
import org.apache.spark.sql.catalyst.ScalaReflection._

test("primitive data") {
val schema = schemaFor[PrimitiveData]
Expand Down
10 changes: 3 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,11 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean

/**
* When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
* When set to true, Spark SQL will use the Janino at runtime to generate custom bytecode
* that evaluates expressions found in queries. In general this custom code runs much faster
* than interpreted evaluation, but there are significant start-up costs due to compilation.
* As a result codegen is only beneficial when queries run for a long time, or when the same
* expressions are used multiple times.
*
* Defaults to false as this feature is currently experimental.
* than interpreted evaluation, but there are some start-up costs (5-10ms) due to compilation.
*/
private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean
private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "true").toBoolean

/**
* caseSensitive analysis true by default
Expand Down
5 changes: 2 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,13 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{InternalRow, _}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.ParserDialect
import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
import org.apache.spark.sql.execution.{Filter, _}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.columnar

import java.nio.{ByteBuffer, ByteOrder}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.columnar.ColumnBuilder._
import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
log.debug(
s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
if (codegenEnabled) {
if (codegenEnabled && expressions.forall(_.isThreadSafe)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so code gen is completely off if I use any thread safe == false function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can improve the coverage later, make all the hive tests passed first.

GenerateProjection.generate(expressions, inputSchema)
} else {
new InterpretedProjection(expressions, inputSchema)
Expand All @@ -168,7 +168,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
inputSchema: Seq[Attribute]): () => MutableProjection = {
log.debug(
s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
if(codegenEnabled) {
if(codegenEnabled && expressions.forall(_.isThreadSafe)) {

GenerateMutableProjection.generate(expressions, inputSchema)
} else {
() => new InterpretedMutableProjection(expressions, inputSchema)
Expand All @@ -178,7 +179,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ

protected def newPredicate(
expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
if (codegenEnabled) {
if (codegenEnabled && expression.isThreadSafe) {
GeneratePredicate.generate(expression, inputSchema)
} else {
InterpretedPredicate.create(expression, inputSchema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,6 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression {
count += 1
(TaskContext.get().partitionId().toLong << 33) + currentCount
}

override def isThreadSafe: Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ case class BroadcastLeftSemiJoinHash(
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
hashSet.add(rowKey)
// rowKey may be not serializable (from codegen)
hashSet.add(rowKey.copy())
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ import org.apache.parquet.hadoop._
import org.apache.parquet.hadoop.metadata.CompressionCodecName
import org.apache.parquet.hadoop.util.ContextUtil

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.{Partition => SparkPartition, SerializableWritable, Logging, SparkException}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD._
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDD._
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.Utils
import org.apache.spark.{Logging, SerializableWritable, SparkException, Partition => SparkPartition}

private[sql] class DefaultSource extends HadoopFsRelationProvider {
override def createRelation(
Expand Down
Loading