Skip to content

Commit 44573a3

Browse files
author
Davies Liu
committed
check thread safety of expression
1 parent f3886fa commit 44573a3

File tree

13 files changed

+102
-67
lines changed

13 files changed

+102
-67
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
460460
case (BooleanType, dt: NumericType) =>
461461
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
462462
case (dt: DecimalType, BooleanType) =>
463-
defineCodeGen(ctx, ev, c => s"$c.isZero()")
463+
defineCodeGen(ctx, ev, c => s"!($c).equals(0)")
464464
case (dt: NumericType, BooleanType) =>
465465
defineCodeGen(ctx, ev, c => s"$c != 0")
466466

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@ abstract class Expression extends TreeNode[Expression] {
5252
/** Returns the result of evaluating this expression on a given input Row */
5353
def eval(input: Row = null): Any
5454

55+
/**
56+
* Return true if this expression is thread-safe, which means it could be used by multiple
57+
* threads in the same time.
58+
*
59+
* An expression that is not thread-safe can not be cached and re-used, especially for codegen.
60+
*/
61+
def isThreadSafe: Boolean = true
62+
5563
/**
5664
* Returns an [[GeneratedExpressionCode]], which contains Java source code that
5765
* can be used to generate the result of evaluating the expression on an input row.
@@ -60,6 +68,9 @@ abstract class Expression extends TreeNode[Expression] {
6068
* @return [[GeneratedExpressionCode]]
6169
*/
6270
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
71+
if (!isThreadSafe) {
72+
throw new Exception(s"$this is not thread-safe, can not be used in codegen")
73+
}
6374
val isNull = ctx.freshName("isNull")
6475
val primitive = ctx.freshName("primitive")
6576
val ve = GeneratedExpressionCode("", isNull, primitive)
@@ -156,6 +167,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
156167

157168
override def toString: String = s"($left $symbol $right)"
158169

170+
override def isThreadSafe: Boolean = left.isThreadSafe && right.isThreadSafe
159171
/**
160172
* Short hand for generating binary evaluation code, which depends on two sub-evaluations of
161173
* the same type. If either of the sub-expressions is null, the result of this computation
@@ -203,6 +215,8 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
203215
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
204216
self: Product =>
205217

218+
override def isThreadSafe: Boolean = child.isThreadSafe
219+
206220
/**
207221
* Called by unary expressions to generate a code block that returns null if its parent returns
208222
* null, and if not not null, use `f` to generate the expression.
@@ -240,6 +254,7 @@ case class GroupExpression(children: Seq[Expression]) extends Expression {
240254
override def nullable: Boolean = false
241255
override def foldable: Boolean = false
242256
override def dataType: DataType = throw new UnsupportedOperationException
257+
override def isThreadSafe: Boolean = children.forall(_.isThreadSafe)
243258
}
244259

245260
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 39 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -347,31 +347,29 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
347347
}
348348

349349
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
350-
if (ctx.isNativeType(left.dataType)) {
351-
val eval1 = left.gen(ctx)
352-
val eval2 = right.gen(ctx)
353-
eval1.code + eval2.code + s"""
354-
boolean ${ev.isNull} = false;
355-
${ctx.javaType(left.dataType)} ${ev.primitive} =
356-
${ctx.defaultValue(left.dataType)};
357-
358-
if (${eval1.isNull}) {
359-
${ev.isNull} = ${eval2.isNull};
360-
${ev.primitive} = ${eval2.primitive};
361-
} else if (${eval2.isNull}) {
362-
${ev.isNull} = ${eval1.isNull};
350+
val eval1 = left.gen(ctx)
351+
val eval2 = right.gen(ctx)
352+
val compCode = ctx.compFunc(dataType)(eval1.primitive, eval2.primitive)
353+
354+
eval1.code + eval2.code + s"""
355+
boolean ${ev.isNull} = false;
356+
${ctx.javaType(left.dataType)} ${ev.primitive} =
357+
${ctx.defaultValue(left.dataType)};
358+
359+
if (${eval1.isNull}) {
360+
${ev.isNull} = ${eval2.isNull};
361+
${ev.primitive} = ${eval2.primitive};
362+
} else if (${eval2.isNull}) {
363+
${ev.isNull} = ${eval1.isNull};
364+
${ev.primitive} = ${eval1.primitive};
365+
} else {
366+
if ($compCode > 0) {
363367
${ev.primitive} = ${eval1.primitive};
364368
} else {
365-
if (${eval1.primitive} > ${eval2.primitive}) {
366-
${ev.primitive} = ${eval1.primitive};
367-
} else {
368-
${ev.primitive} = ${eval2.primitive};
369-
}
369+
${ev.primitive} = ${eval2.primitive};
370370
}
371-
"""
372-
} else {
373-
super.genCode(ctx, ev)
374-
}
371+
}
372+
"""
375373
}
376374
override def toString: String = s"MaxOf($left, $right)"
377375
}
@@ -401,33 +399,29 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
401399
}
402400

403401
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
404-
if (ctx.isNativeType(left.dataType)) {
405-
406-
val eval1 = left.gen(ctx)
407-
val eval2 = right.gen(ctx)
408-
409-
eval1.code + eval2.code + s"""
410-
boolean ${ev.isNull} = false;
411-
${ctx.javaType(left.dataType)} ${ev.primitive} =
412-
${ctx.defaultValue(left.dataType)};
402+
val eval1 = left.gen(ctx)
403+
val eval2 = right.gen(ctx)
404+
val compCode = ctx.compFunc(dataType)(eval1.primitive, eval2.primitive)
413405

414-
if (${eval1.isNull}) {
415-
${ev.isNull} = ${eval2.isNull};
416-
${ev.primitive} = ${eval2.primitive};
417-
} else if (${eval2.isNull}) {
418-
${ev.isNull} = ${eval1.isNull};
406+
eval1.code + eval2.code + s"""
407+
boolean ${ev.isNull} = false;
408+
${ctx.javaType(left.dataType)} ${ev.primitive} =
409+
${ctx.defaultValue(left.dataType)};
410+
411+
if (${eval1.isNull}) {
412+
${ev.isNull} = ${eval2.isNull};
413+
${ev.primitive} = ${eval2.primitive};
414+
} else if (${eval2.isNull}) {
415+
${ev.isNull} = ${eval1.isNull};
416+
${ev.primitive} = ${eval1.primitive};
417+
} else {
418+
if ($compCode < 0) {
419419
${ev.primitive} = ${eval1.primitive};
420420
} else {
421-
if (${eval1.primitive} < ${eval2.primitive}) {
422-
${ev.primitive} = ${eval1.primitive};
423-
} else {
424-
${ev.primitive} = ${eval2.primitive};
425-
}
421+
${ev.primitive} = ${eval2.primitive};
426422
}
427-
"""
428-
} else {
429-
super.genCode(ctx, ev)
430-
}
423+
}
424+
"""
431425
}
432426

433427
override def toString: String = s"MinOf($left, $right)"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,27 @@ class CodeGenContext {
164164
def equalFunc(dataType: DataType): ((String, String) => String) = dataType match {
165165
case BinaryType => { case (eval1, eval2) =>
166166
s"java.util.Arrays.equals($eval1, $eval2)" }
167-
case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType =>
167+
case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType
168+
| DateType =>
168169
{ case (eval1, eval2) => s"$eval1 == $eval2" }
169170
case other =>
170171
{ case (eval1, eval2) => s"$eval1.equals($eval2)" }
171172
}
172173

174+
/**
175+
* Return a function to generate compare expression in Java
176+
*/
177+
def compFunc(dataType: DataType): (String, String) => String = dataType match {
178+
case BinaryType => {
179+
case (c1, c2) =>
180+
s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
181+
}
182+
case IntegerType | LongType | DoubleType | FloatType | ShortType | ByteType | DateType => {
183+
case (c1, c2) => s"$c1 - $c2"
184+
}
185+
case other => { case (c1, c2) => s"$c1.compare($c2)" }
186+
}
187+
173188
/**
174189
* List of data types that have special accessors and setters in [[Row]].
175190
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import org.apache.spark.Logging
2121
import org.apache.spark.annotation.Private
2222
import org.apache.spark.sql.Row
2323
import org.apache.spark.sql.catalyst.expressions._
24-
import org.apache.spark.sql.types.{DateType, BinaryType, NumericType}
24+
import org.apache.spark.sql.types.{DateType, BinaryType, DecimalType, NumericType}
2525

2626
/**
2727
* Inherits some default implementation for Java from `Ordering[Row]`
@@ -66,7 +66,16 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
6666
return d;
6767
}
6868
}"""
69-
case _: NumericType | DateType =>
69+
case dt: NumericType if !dt.isInstanceOf[DecimalType] =>
70+
s"""
71+
if (${evalA.primitive} != ${evalB.primitive}) {
72+
if (${evalA.primitive} > ${evalB.primitive}) {
73+
return ${if (asc) "1" else "-1"};
74+
} else {
75+
return ${if (asc) "-1" else "1"};
76+
}
77+
}"""
78+
case DateType =>
7079
s"""
7180
if (${evalA.primitive} != ${evalB.primitive}) {
7281
if (${evalA.primitive} > ${evalB.primitive}) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ case class Alias(child: Expression, name: String)(
116116

117117
override def eval(input: Row): Any = child.eval(input)
118118

119+
override def isThreadSafe: Boolean = child.isThreadSafe
120+
119121
override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
120122

121123
override def dataType: DataType = child.dataType

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
21-
import org.apache.spark.sql.catalyst.trees
2220
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
21+
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
2322
import org.apache.spark.sql.types.DataType
2423

2524
case class Coalesce(children: Seq[Expression]) extends Expression {
@@ -53,6 +52,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
5352
result
5453
}
5554

55+
override def isThreadSafe: Boolean = children.forall(_.isThreadSafe)
56+
5657
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
5758
s"""
5859
boolean ${ev.isNull} = true;
@@ -73,7 +74,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
7374
}
7475
}
7576

76-
case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
77+
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
7778
override def foldable: Boolean = child.foldable
7879
override def nullable: Boolean = false
7980

@@ -91,7 +92,7 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr
9192
override def toString: String = s"IS NULL $child"
9293
}
9394

94-
case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
95+
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
9596
override def foldable: Boolean = child.foldable
9697
override def nullable: Boolean = false
9798
override def toString: String = s"IS NOT NULL $child"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
2121
import org.apache.spark.sql.catalyst.errors.TreeNodeException
22-
import org.apache.spark.sql.catalyst.expressions._
23-
import org.apache.spark.sql.types.{NumericType, DataType}
22+
import org.apache.spark.sql.types.{DataType, NumericType}
2423

2524
/**
2625
* The trait of the Window Specification (specified in the OVER clause or WINDOW clause) for

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
154154
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
155155
log.debug(
156156
s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
157-
if (codegenEnabled) {
157+
if (codegenEnabled && expressions.forall(_.isThreadSafe)) {
158158
GenerateProjection.generate(expressions, inputSchema)
159159
} else {
160160
new InterpretedProjection(expressions, inputSchema)
@@ -166,7 +166,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
166166
inputSchema: Seq[Attribute]): () => MutableProjection = {
167167
log.debug(
168168
s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
169-
if(codegenEnabled) {
169+
if(codegenEnabled && expressions.forall(_.isThreadSafe)) {
170+
170171
GenerateMutableProjection.generate(expressions, inputSchema)
171172
} else {
172173
() => new InterpretedMutableProjection(expressions, inputSchema)
@@ -176,7 +177,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
176177

177178
protected def newPredicate(
178179
expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = {
179-
if (codegenEnabled) {
180+
if (codegenEnabled && expression.isThreadSafe) {
180181
GeneratePredicate.generate(expression, inputSchema)
181182
} else {
182183
InterpretedPredicate.create(expression, inputSchema)

sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,5 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression {
4949
(TaskContext.get().partitionId().toLong << 33) + currentCount
5050
}
5151

52-
/**
53-
* This expression is stateful and not thread safe, it can not be reused. Change equals() to
54-
* always return `false`, make the generated expressions to not be cached.
55-
*/
56-
override def equals(other: Any): Boolean = {
57-
false
58-
}
52+
override def isThreadSafe: Boolean = false
5953
}

0 commit comments

Comments
 (0)