
Commit d64e290

[SPARK-19372][SQL] Fix throwing a Java exception at df.filter() due to 64KB bytecode size limit (apache#171)
1 parent 20e6d74 commit d64e290

7 files changed, +67 -12 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 23 additions & 4 deletions
@@ -27,7 +27,10 @@ import scala.language.existentials
 import scala.util.control.NonFatal
 
 import com.google.common.cache.{CacheBuilder, CacheLoader}
-import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler}
+import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
+import org.apache.commons.lang3.exception.ExceptionUtils
+import org.codehaus.commons.compiler.CompileException
+import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, JaninoRuntimeException, SimpleCompiler}
 import org.codehaus.janino.util.ClassFile
 
 import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException}
@@ -899,8 +902,20 @@ object CodeGenerator extends Logging {
   /**
    * Compile the Java source code into a Java class, using Janino.
    */
-  def compile(code: CodeAndComment): GeneratedClass = {
+  def compile(code: CodeAndComment): GeneratedClass = try {
     cache.get(code)
+  } catch {
+    // Cache.get() may wrap the original exception. See the following URL
+    // http://google.github.io/guava/releases/14.0/api/docs/com/google/common/cache/
+    // Cache.html#get(K,%20java.util.concurrent.Callable)
+    case e : UncheckedExecutionException =>
+      val excChains = ExceptionUtils.getThrowables(e)
+      val exc = if (excChains.length == 1) excChains(0) else excChains(excChains.length - 2)
+      throw exc
+    case e : ExecutionError =>
+      val excChains = ExceptionUtils.getThrowables(e)
+      val exc = if (excChains.length == 1) excChains(0) else excChains(excChains.length - 2)
+      throw exc
   }
 
   /**
@@ -951,10 +966,14 @@
       evaluator.cook("generated.java", code.body)
       recordCompilationStats(evaluator)
     } catch {
-      case e: Exception =>
+      case e: JaninoRuntimeException =>
         val msg = s"failed to compile: $e\n$formatted"
         logError(msg, e)
-        throw new Exception(msg, e)
+        throw new JaninoRuntimeException(msg, e)
+      case e: CompileException =>
+        val msg = s"failed to compile: $e\n$formatted"
+        logError(msg, e)
+        throw new CompileException(msg, e.asInstanceOf[CompileException].getLocation)
     }
     evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass]
  }
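
Note on the compile() change above: Guava's LoadingCache.get() wraps any exception thrown by the cache loader (here, the Janino compilation in doCompile) in an UncheckedExecutionException or ExecutionError, so without unwrapping the caller would see the Guava wrapper rather than the compiler error. The patch walks the cause chain with ExceptionUtils.getThrowables and rethrows the throwable just above the root cause, which is the one carrying the formatted "failed to compile" message. Below is a minimal standalone sketch of that unwrapping pattern; the loader, key, and messages are hypothetical stand-ins, not Spark code.

import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
import com.google.common.util.concurrent.UncheckedExecutionException
import org.apache.commons.lang3.exception.ExceptionUtils

object CacheUnwrapSketch {
  // Hypothetical loader standing in for doCompile(): it fails and, like the
  // patched doCompile, rethrows a wrapper carrying a formatted message.
  private val cache: LoadingCache[String, String] = CacheBuilder.newBuilder()
    .maximumSize(100L)
    .build(new CacheLoader[String, String] {
      override def load(key: String): String =
        throw new RuntimeException(s"failed to compile: $key",
          new RuntimeException("original compiler error"))
    })

  def main(args: Array[String]): Unit = {
    try {
      cache.get("class Foo {}")
    } catch {
      // LoadingCache.get() wraps unchecked loader exceptions in this type
      // (Errors arrive as ExecutionError, handled the same way in the patch).
      case e: UncheckedExecutionException =>
        // Chain: [UncheckedExecutionException, RuntimeException("failed to compile: ..."),
        // RuntimeException("original compiler error")]. Taking the second-to-last element
        // recovers the exception with the formatted message, as the patch does.
        val chain = ExceptionUtils.getThrowables(e)
        val original = if (chain.length == 1) chain(0) else chain(chain.length - 2)
        // The patch rethrows it; here we just print it.
        println(s"unwrapped: ${original.getMessage}")
    }
  }
}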

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 6 additions & 4 deletions
@@ -20,20 +20,22 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => BasePredicate}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 
 object InterpretedPredicate {
-  def create(expression: Expression, inputSchema: Seq[Attribute]): (InternalRow => Boolean) =
+  def create(expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate =
     create(BindReferences.bindReference(expression, inputSchema))
 
-  def create(expression: Expression): (InternalRow => Boolean) = {
-    (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean]
-  }
+  def create(expression: Expression): InterpretedPredicate = new InterpretedPredicate(expression)
 }
 
+class InterpretedPredicate(expression: Expression) extends BasePredicate {
+  def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean]
+}
 
 /**
  * An [[Expression]] that returns a boolean value.
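
The practical effect of the predicates.scala change is that InterpretedPredicate.create now returns an object exposing eval(InternalRow), sharing the same base class as code-generated predicates, instead of a bare InternalRow => Boolean function. That is why later call sites in this commit switch from boundPredicate(row) to boundPredicate.eval(row). A simplified sketch of the idea follows; the Row and *Sketch types are hypothetical stand-ins rather than Spark's catalyst classes.

// Stand-in types: Row and the predicate base are simplifications, not Spark classes.
final case class Row(values: Seq[Any])

trait BasePredicateSketch {
  def eval(row: Row): Boolean
}

// The interpreted fallback: evaluates the expression directly instead of
// running code-generated bytecode.
class InterpretedPredicateSketch(f: Row => Boolean) extends BasePredicateSketch {
  def eval(row: Row): Boolean = f(row)
}

object PredicateSketch {
  def main(args: Array[String]): Unit = {
    // Either an interpreted or a generated predicate can sit behind this type,
    // so callers such as SparkPlan.newPredicate can fall back transparently.
    val p: BasePredicateSketch =
      new InterpretedPredicateSketch(row => row.values.exists(_ != "string"))
    println(p.eval(Row(Seq("string", "other"))))  // true
  }
}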

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 23 additions & 1 deletion
@@ -22,6 +22,9 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.ExecutionContext
 
+import org.codehaus.commons.compiler.CompileException
+import org.codehaus.janino.JaninoRuntimeException
+
 import org.apache.spark.{broadcast, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.io.CompressionCodec
@@ -353,9 +356,28 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination)
   }
 
+  private def genInterpretedPredicate(
+      expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = {
+    val str = expression.toString
+    val logMessage = if (str.length > 256) {
+      str.substring(0, 256 - 3) + "..."
+    } else {
+      str
+    }
+    logWarning(s"Codegen disabled for this expression:\n $logMessage")
+    InterpretedPredicate.create(expression, inputSchema)
+  }
+
   protected def newPredicate(
       expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = {
-    GeneratePredicate.generate(expression, inputSchema)
+    try {
+      GeneratePredicate.generate(expression, inputSchema)
+    } catch {
+      case e: JaninoRuntimeException if sqlContext == null || sqlContext.conf.wholeStageFallback =>
+        genInterpretedPredicate(expression, inputSchema)
+      case e: CompileException if sqlContext == null || sqlContext.conf.wholeStageFallback =>
+        genInterpretedPredicate(expression, inputSchema)
+    }
   }
 
   protected def newOrdering(
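
The newPredicate change above is where the fallback actually happens: code generation is still attempted first, but if Janino rejects the generated predicate (for example because a single method exceeds the JVM's 64KB bytecode limit) and the whole-stage fallback setting allows it, the expression is evaluated with the interpreted predicate instead of failing the query. A minimal sketch of that try-codegen-then-fall-back shape is shown below; compileGenerated, interpret, and TooLargeToCompile are hypothetical stand-ins, not Spark APIs.

object FallbackSketch {
  final class TooLargeToCompile(msg: String) extends RuntimeException(msg)

  type Predicate = Seq[Any] => Boolean

  // Pretend code generation fails once the predicate gets too wide.
  def compileGenerated(numClauses: Int): Predicate = {
    if (numClauses > 300) {
      throw new TooLargeToCompile(s"generated method for $numClauses clauses exceeds 64KB")
    }
    row => row.exists(_ != "string")
  }

  // Slower but always-available interpreted path.
  def interpret(numClauses: Int): Predicate =
    row => row.exists(_ != "string")

  // Mirrors the shape of SparkPlan.newPredicate: try codegen, log a warning,
  // and fall back to interpretation when compilation fails and fallback is allowed.
  def newPredicate(numClauses: Int, fallbackEnabled: Boolean = true): Predicate = {
    try {
      compileGenerated(numClauses)
    } catch {
      case e: TooLargeToCompile if fallbackEnabled =>
        println(s"Codegen disabled for this expression: ${e.getMessage}")
        interpret(numClauses)
    }
  }

  def main(args: Array[String]): Unit = {
    val predicate = newPredicate(numClauses = 400)
    println(predicate(Seq.fill(400)("string")))  // false: no column differs
  }
}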

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ abstract class PartitioningAwareFileIndex(
       })
 
     val selected = partitions.filter {
-      case PartitionPath(values, _) => boundPredicate(values)
+      case PartitionPath(values, _) => boundPredicate.eval(values)
     }
     logInfo {
       val total = partitions.length

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 12 additions & 0 deletions
@@ -24,6 +24,7 @@ import java.util.UUID
 
 import scala.util.Random
 
+import com.sun.net.httpserver.Authenticator.Retry
 import org.scalatest.Matchers._
 
 import org.apache.spark.SparkException
@@ -1703,4 +1704,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)")
     checkAnswer(df, Row(BigDecimal(0.0)) :: Nil)
   }
+
+  test("SPARK-19372: Filter can be executed w/o generated code due to JVM code size limit") {
+    val N = 400
+    val rows = Seq(Row.fromSeq(Seq.fill(N)("string")))
+    val schema = StructType(Seq.tabulate(N)(i => StructField(s"_c$i", StringType)))
+    val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema)
+
+    val filter = (0 until N)
+      .foldLeft(lit(false))((e, index) => e.or(df.col(df.columns(index)) =!= "string"))
+    df.filter(filter).count
+  }
 }

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala

Lines changed: 1 addition & 1 deletion
@@ -1036,7 +1036,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
           BoundReference(index, partitionSchema(index).dataType, nullable = true)
         })
         clientPrunedPartitions.filter { p =>
-          boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId))
+          boundPredicate.eval(p.toRow(partitionSchema, defaultTimeZoneId))
         }
       } else {
         client.getPartitions(catalogTable).map { part =>

sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
         // `Cast`ed values are always of internal types (e.g. UTF8String instead of String)
         Cast(Literal(value), dataType).eval()
       })
-    }.filter(predicate).map(projection)
+    }.filter(predicate.eval).map(projection)
 
     // Appends partition values
     val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
