Skip to content

Commit 72ba0fc

Browse files
committed
[SPARK-8154][SQL] Remove Term/Code type aliases in code generation.
From my perspective as a code reviewer, I find them more confusing than using String directly. Author: Reynold Xin <[email protected]> Closes #6694 from rxin/SPARK-8154 and squashes the following commits: 4e5056c [Reynold Xin] [SPARK-8154][SQL] Remove Term/Code type aliases in code generation.
1 parent f74be74 commit 72ba0fc

File tree

15 files changed

+69
-66
lines changed

15 files changed

+69
-66
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.Logging
2121
import org.apache.spark.sql.catalyst.errors.attachTree
22-
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
22+
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
2323
import org.apache.spark.sql.types._
2424
import org.apache.spark.sql.catalyst.trees
2525

@@ -43,7 +43,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
4343

4444
override def exprId: ExprId = throw new UnsupportedOperationException
4545

46-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
46+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
4747
s"""
4848
boolean ${ev.isNull} = i.isNullAt($ordinal);
4949
${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}
2121
import java.text.{DateFormat, SimpleDateFormat}
2222

2323
import org.apache.spark.Logging
24-
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
24+
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
2525
import org.apache.spark.sql.catalyst.util.DateUtils
2626
import org.apache.spark.sql.types._
2727

@@ -435,7 +435,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
435435
if (evaluated == null) null else cast(evaluated)
436436
}
437437

438-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
438+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
439439
// TODO(cg): Add support for more data types.
440440
(child.dataType, dataType) match {
441441

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
21-
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term}
21+
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
2222
import org.apache.spark.sql.catalyst.trees
2323
import org.apache.spark.sql.catalyst.trees.TreeNode
2424
import org.apache.spark.sql.types._
@@ -76,7 +76,7 @@ abstract class Expression extends TreeNode[Expression] {
7676
* @param ev an [[GeneratedExpressionCode]] with unique terms.
7777
* @return Java source code
7878
*/
79-
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
79+
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
8080
ctx.references += this
8181
val objectTerm = ctx.freshName("obj")
8282
s"""
@@ -166,7 +166,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
166166
protected def defineCodeGen(
167167
ctx: CodeGenContext,
168168
ev: GeneratedExpressionCode,
169-
f: (Term, Term) => Code): String = {
169+
f: (String, String) => String): String = {
170170
// TODO: Right now some timestamp tests fail if we enforce this...
171171
if (left.dataType != right.dataType) {
172172
// log.warn(s"${left.dataType} != ${right.dataType}")
@@ -182,7 +182,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
182182
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
183183
if (!${ev.isNull}) {
184184
${eval2.code}
185-
if(!${eval2.isNull}) {
185+
if (!${eval2.isNull}) {
186186
${ev.primitive} = $resultCode;
187187
} else {
188188
${ev.isNull} = true;
@@ -217,7 +217,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
217217
protected def defineCodeGen(
218218
ctx: CodeGenContext,
219219
ev: GeneratedExpressionCode,
220-
f: Term => Code): Code = {
220+
f: String => String): String = {
221221
val eval = child.gen(ctx)
222222
// reuse the previous isNull
223223
ev.isNull = eval.isNull

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
21-
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, GeneratedExpressionCode, CodeGenContext}
21+
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
2222
import org.apache.spark.sql.catalyst.util.TypeUtils
2323
import org.apache.spark.sql.types._
2424

@@ -50,7 +50,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
5050

5151
private lazy val numeric = TypeUtils.getNumeric(dataType)
5252

53-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
53+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
5454
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()")
5555
case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)")
5656
}
@@ -74,7 +74,7 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
7474
else math.sqrt(value)
7575
}
7676

77-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
77+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
7878
val eval = child.gen(ctx)
7979
eval.code + s"""
8080
boolean ${ev.isNull} = ${eval.isNull};
@@ -138,7 +138,7 @@ abstract class BinaryArithmetic extends BinaryExpression {
138138
}
139139
}
140140

141-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
141+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
142142
case dt: DecimalType =>
143143
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
144144
// byte and short are casted into int when add, minus, times or divide
@@ -236,7 +236,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
236236
/**
237237
* Special case handling due to division by 0 => null.
238238
*/
239-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
239+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
240240
val eval1 = left.gen(ctx)
241241
val eval2 = right.gen(ctx)
242242
val test = if (left.dataType.isInstanceOf[DecimalType]) {
@@ -296,7 +296,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
296296
/**
297297
* Special case handling for x % 0 ==> null.
298298
*/
299-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
299+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
300300
val eval1 = left.gen(ctx)
301301
val eval2 = right.gen(ctx)
302302
val test = if (left.dataType.isInstanceOf[DecimalType]) {
@@ -346,7 +346,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
346346
}
347347
}
348348

349-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
349+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
350350
if (ctx.isNativeType(left.dataType)) {
351351
val eval1 = left.gen(ctx)
352352
val eval2 = right.gen(ctx)
@@ -400,7 +400,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
400400
}
401401
}
402402

403-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
403+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
404404
if (ctx.isNativeType(left.dataType)) {
405405

406406
val eval1 = left.gen(ctx)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import org.apache.spark.sql.types._
2525

2626
/**
2727
* A function that calculates bitwise and(&) of two numbers.
28+
*
29+
* Code generation inherited from BinaryArithmetic.
2830
*/
2931
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
3032
override def symbol: String = "&"
@@ -48,6 +50,8 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
4850

4951
/**
5052
* A function that calculates bitwise or(|) of two numbers.
53+
*
54+
* Code generation inherited from BinaryArithmetic.
5155
*/
5256
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
5357
override def symbol: String = "|"
@@ -71,6 +75,8 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
7175

7276
/**
7377
* A function that calculates bitwise xor of two numbers.
78+
*
79+
* Code generation inherited from BinaryArithmetic.
7480
*/
7581
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
7682
override def symbol: String = "^"
@@ -112,8 +118,8 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
112118
((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any]
113119
}
114120

115-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
116-
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)")
121+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
122+
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)")
117123
}
118124

119125
protected override def evalInternal(evalE: Any) = not(evalE)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
4040
* @param primitive A term for a possible primitive value of the result of the evaluation. Not
4141
* valid if `isNull` is set to `true`.
4242
*/
43-
case class GeneratedExpressionCode(var code: Code, var isNull: Term, var primitive: Term)
43+
case class GeneratedExpressionCode(var code: String, var isNull: String, var primitive: String)
4444

4545
/**
4646
* A context for codegen, which is used to bookkeeping the expressions those are not supported
@@ -65,14 +65,14 @@ class CodeGenContext {
6565
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
6666
* function.)
6767
*/
68-
def freshName(prefix: String): Term = {
68+
def freshName(prefix: String): String = {
6969
s"$prefix${curId.getAndIncrement}"
7070
}
7171

7272
/**
7373
* Return the code to access a column for given DataType
7474
*/
75-
def getColumn(dataType: DataType, ordinal: Int): Code = {
75+
def getColumn(dataType: DataType, ordinal: Int): String = {
7676
if (isNativeType(dataType)) {
7777
s"i.${accessorForType(dataType)}($ordinal)"
7878
} else {
@@ -83,7 +83,7 @@ class CodeGenContext {
8383
/**
8484
* Return the code to update a column in Row for given DataType
8585
*/
86-
def setColumn(dataType: DataType, ordinal: Int, value: Term): Code = {
86+
def setColumn(dataType: DataType, ordinal: Int, value: String): String = {
8787
if (isNativeType(dataType)) {
8888
s"${mutatorForType(dataType)}($ordinal, $value)"
8989
} else {
@@ -94,23 +94,23 @@ class CodeGenContext {
9494
/**
9595
* Return the name of accessor in Row for a DataType
9696
*/
97-
def accessorForType(dt: DataType): Term = dt match {
97+
def accessorForType(dt: DataType): String = dt match {
9898
case IntegerType => "getInt"
9999
case other => s"get${boxedType(dt)}"
100100
}
101101

102102
/**
103103
* Return the name of mutator in Row for a DataType
104104
*/
105-
def mutatorForType(dt: DataType): Term = dt match {
105+
def mutatorForType(dt: DataType): String = dt match {
106106
case IntegerType => "setInt"
107107
case other => s"set${boxedType(dt)}"
108108
}
109109

110110
/**
111111
* Return the Java type for a DataType
112112
*/
113-
def javaType(dt: DataType): Term = dt match {
113+
def javaType(dt: DataType): String = dt match {
114114
case IntegerType => "int"
115115
case LongType => "long"
116116
case ShortType => "short"
@@ -131,7 +131,7 @@ class CodeGenContext {
131131
/**
132132
* Return the boxed type in Java
133133
*/
134-
def boxedType(dt: DataType): Term = dt match {
134+
def boxedType(dt: DataType): String = dt match {
135135
case IntegerType => "Integer"
136136
case LongType => "Long"
137137
case ShortType => "Short"
@@ -146,7 +146,7 @@ class CodeGenContext {
146146
/**
147147
* Return the representation of default value for given DataType
148148
*/
149-
def defaultValue(dt: DataType): Term = dt match {
149+
def defaultValue(dt: DataType): String = dt match {
150150
case BooleanType => "false"
151151
case FloatType => "-1.0f"
152152
case ShortType => "(short)-1"
@@ -161,7 +161,7 @@ class CodeGenContext {
161161
/**
162162
* Returns a function to generate equal expression in Java
163163
*/
164-
def equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match {
164+
def equalFunc(dataType: DataType): ((String, String) => String) = dataType match {
165165
case BinaryType => { case (eval1, eval2) =>
166166
s"java.util.Arrays.equals($eval1, $eval2)" }
167167
case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@ import org.apache.spark.util.Utils
2727
*/
2828
package object codegen {
2929

30-
type Term = String
31-
type Code = String
32-
3330
/** Canonicalizes an expression so those that differ only by names can reuse the same code. */
3431
object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] {
3532
val batches =

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
5050
}
5151
}
5252

53-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
53+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
5454
val condEval = predicate.gen(ctx)
5555
val trueEval = trueValue.gen(ctx)
5656
val falseEval = falseValue.gen(ctx)
@@ -155,7 +155,7 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
155155
return res
156156
}
157157

158-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
158+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
159159
val len = branchesArr.length
160160
val got = ctx.freshName("got")
161161

@@ -248,7 +248,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
248248
return res
249249
}
250250

251-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
251+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
252252
val keyEval = key.gen(ctx)
253253
val len = branchesArr.length
254254
val got = ctx.freshName("got")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
20+
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
2121
import org.apache.spark.sql.types._
2222

2323
/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */
@@ -37,7 +37,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
3737
}
3838
}
3939

40-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
40+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
4141
defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()")
4242
}
4343
}
@@ -59,7 +59,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
5959
}
6060
}
6161

62-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
62+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
6363
val eval = child.gen(ctx)
6464
eval.code + s"""
6565
boolean ${ev.isNull} = ${eval.isNull};

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
2020
import java.sql.{Date, Timestamp}
2121

2222
import org.apache.spark.sql.catalyst.CatalystTypeConverters
23-
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode}
23+
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
2424
import org.apache.spark.sql.catalyst.util.DateUtils
2525
import org.apache.spark.sql.types._
2626

@@ -88,7 +88,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
8888

8989
override def eval(input: Row): Any = value
9090

91-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
91+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
9292
// change the isNull and primitive to consts, to inline them
9393
if (value == null) {
9494
ev.isNull = "true"

0 commit comments

Comments
 (0)