diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 29585acd4759a..73ae66c09c56c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DecimalType, DoubleType}
-
+import org.apache.spark.util.Utils
 
 private[sql] class HiveSessionCatalog(
     externalCatalogBuilder: () => ExternalCatalog,
@@ -65,49 +65,52 @@ private[sql] class HiveSessionCatalog(
       name: String,
       clazz: Class[_],
       input: Seq[Expression]): Expression = {
-
-    Try(super.makeFunctionExpression(name, clazz, input)).getOrElse {
-      var udfExpr: Option[Expression] = None
-      try {
-        // When we instantiate hive UDF wrapper class, we may throw exception if the input
-        // expressions don't satisfy the hive UDF, such as type mismatch, input number
-        // mismatch, etc. Here we catch the exception and throw AnalysisException instead.
-        if (classOf[UDF].isAssignableFrom(clazz)) {
-          udfExpr = Some(HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), input))
-          udfExpr.get.dataType // Force it to check input data types.
-        } else if (classOf[GenericUDF].isAssignableFrom(clazz)) {
-          udfExpr = Some(HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), input))
-          udfExpr.get.dataType // Force it to check input data types.
-        } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) {
-          udfExpr = Some(HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), input))
-          udfExpr.get.dataType // Force it to check input data types.
-        } else if (classOf[UDAF].isAssignableFrom(clazz)) {
-          udfExpr = Some(HiveUDAFFunction(
-            name,
-            new HiveFunctionWrapper(clazz.getName),
-            input,
-            isUDAFBridgeRequired = true))
-          udfExpr.get.dataType // Force it to check input data types.
-        } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) {
-          udfExpr = Some(HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), input))
-          udfExpr.get.asInstanceOf[HiveGenericUDTF].elementSchema // Force it to check data types.
-        }
+    // The current thread's context classloader may not be the one that loaded the class, so
+    // switch to the class's own classloader to initialize the instance properly.
+    Utils.withContextClassLoader(clazz.getClassLoader) {
+      Try(super.makeFunctionExpression(name, clazz, input)).getOrElse {
+        var udfExpr: Option[Expression] = None
+        try {
+          // When we instantiate hive UDF wrapper class, we may throw exception if the input
+          // expressions don't satisfy the hive UDF, such as type mismatch, input number
+          // mismatch, etc. Here we catch the exception and throw AnalysisException instead.
+          if (classOf[UDF].isAssignableFrom(clazz)) {
+            udfExpr = Some(HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), input))
+            udfExpr.get.dataType // Force it to check input data types.
+          } else if (classOf[GenericUDF].isAssignableFrom(clazz)) {
+            udfExpr = Some(HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), input))
+            udfExpr.get.dataType // Force it to check input data types.
+          } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) {
+            udfExpr = Some(HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), input))
+            udfExpr.get.dataType // Force it to check input data types.
+          } else if (classOf[UDAF].isAssignableFrom(clazz)) {
+            udfExpr = Some(HiveUDAFFunction(
+              name,
+              new HiveFunctionWrapper(clazz.getName),
+              input,
+              isUDAFBridgeRequired = true))
+            udfExpr.get.dataType // Force it to check input data types.
+          } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) {
+            udfExpr = Some(HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), input))
+            udfExpr.get.asInstanceOf[HiveGenericUDTF].elementSchema // Force it to check data types.
+          }
+        } catch {
+          case NonFatal(e) =>
+            val noHandlerMsg = s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e"
+            val errorMsg =
+              if (classOf[GenericUDTF].isAssignableFrom(clazz)) {
+                s"$noHandlerMsg\nPlease make sure your function overrides " +
+                  "`public StructObjectInspector initialize(ObjectInspector[] args)`."
+              } else {
+                noHandlerMsg
+              }
+            val analysisException = new AnalysisException(errorMsg)
+            analysisException.setStackTrace(e.getStackTrace)
+            throw analysisException
+        }
+        udfExpr.getOrElse {
+          throw new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}'")
+        }
+      }
-      } catch {
-        case NonFatal(e) =>
-          val noHandlerMsg = s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e"
-          val errorMsg =
-            if (classOf[GenericUDTF].isAssignableFrom(clazz)) {
-              s"$noHandlerMsg\nPlease make sure your function overrides " +
-                "`public StructObjectInspector initialize(ObjectInspector[] args)`."
-            } else {
-              noHandlerMsg
-            }
-          val analysisException = new AnalysisException(errorMsg)
-          analysisException.setStackTrace(e.getStackTrace)
-          throw analysisException
-      }
-      udfExpr.getOrElse {
-        throw new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}'")
-      }
     }
   }
diff --git a/sql/hive/src/test/noclasspath/README b/sql/hive/src/test/noclasspath/README
new file mode 100644
index 0000000000000..8ce1b0bd09668
--- /dev/null
+++ b/sql/hive/src/test/noclasspath/README
@@ -0,0 +1 @@
+Place files here that are used as test resources but should not be added to the classpath.
\ No newline at end of file
diff --git a/sql/hive/src/test/noclasspath/TestUDTF-spark-26560.jar b/sql/hive/src/test/noclasspath/TestUDTF-spark-26560.jar
new file mode 100644
index 0000000000000..b73b17d5c7880
Binary files /dev/null and b/sql/hive/src/test/noclasspath/TestUDTF-spark-26560.jar differ
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 2fcc8229a9f57..9f57776233e2f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.util.Utils
 
 case class Nested1(f1: Nested2)
 case class Nested2(f2: Nested3)
@@ -2327,4 +2328,51 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       }
     }
   }
+
+  test("SPARK-26560 Spark should be able to run Hive UDF using jar regardless of " +
+    "current thread context classloader") {
+    // Force the Spark classloader, as other tests (even in other test suites) may have changed
+    // the current thread's context classloader to the jar classloader.
+    Utils.withContextClassLoader(Utils.getSparkClassLoader) {
+      withUserDefinedFunction("udtf_count3" -> false) {
+        val sparkClassLoader = Thread.currentThread().getContextClassLoader
+
+        // This jar file must not be placed on the classpath: GenericUDTFCount3 is a slightly
+        // modified version of GenericUDTFCount2 in hive/contrib, which emits the count three
+        // times.
+        val jarPath = "src/test/noclasspath/TestUDTF-spark-26560.jar"
+        val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath"
+
+        sql(
+          s"""
+             |CREATE FUNCTION udtf_count3
+             |AS 'org.apache.hadoop.hive.contrib.udtf.example.GenericUDTFCount3'
+             |USING JAR '$jarURL'
+           """.stripMargin)
+
+        assert(Thread.currentThread().getContextClassLoader eq sparkClassLoader)
+
+        // The JAR is loaded on first use, which switches the current thread's context
+        // classloader to the jar classloader in sharedState.
+        // See SessionState.addJar for details.
+        checkAnswer(
+          sql("SELECT udtf_count3(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"),
+          Row(3) :: Row(3) :: Row(3) :: Nil)
+
+        assert(Thread.currentThread().getContextClassLoader ne sparkClassLoader)
+        assert(Thread.currentThread().getContextClassLoader eq
+          spark.sqlContext.sharedState.jarClassLoader)
+
+        // Roll back to the original classloader and run the query again. Without this line the
+        // test would still pass, since the thread's context classloader stays the jar
+        // classloader. But other code can change the context classloader as well and thereby
+        // break the query; for example, spark-shell rolls the context classloader back
+        // automatically. This mimics the behavior of spark-shell.
+        Thread.currentThread().setContextClassLoader(sparkClassLoader)
+        checkAnswer(
+          sql("SELECT udtf_count3(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"),
+          Row(3) :: Row(3) :: Row(3) :: Nil)
+      }
+    }
+  }
 }
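
Note: the `Utils.withContextClassLoader` helper this patch relies on is not part of the diff shown here. A minimal sketch of its assumed behavior (set the thread's context classloader for the duration of a block, then restore the previous one, even if the block throws):

```scala
// Sketch only: assumed semantics of org.apache.spark.util.Utils.withContextClassLoader,
// which is defined outside this diff. Runs `fn` with `ctxClassLoader` as the current
// thread's context classloader and restores the previous classloader afterwards.
def withContextClassLoader[T](ctxClassLoader: ClassLoader)(fn: => T): T = {
  val oldClassLoader = Thread.currentThread().getContextClassLoader
  try {
    Thread.currentThread().setContextClassLoader(ctxClassLoader)
    fn
  } finally {
    // Restore even on failure, so a failed UDF lookup cannot leak the jar
    // classloader into unrelated code running on this thread.
    Thread.currentThread().setContextClassLoader(oldClassLoader)
  }
}
```

The restore-on-exit shape is what makes both the catalog change and the test's classloader assertions safe to nest.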
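The contents of `TestUDTF-spark-26560.jar` are binary and not visible in the diff; the test comment describes `GenericUDTFCount3` as a slightly modified `GenericUDTFCount2` from hive/contrib that emits the count three times. A hypothetical reconstruction in that spirit (class body is an assumption, not the actual jar contents):

```scala
import java.util.Arrays

import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

// Hypothetical sketch of the UDTF packaged in TestUDTF-spark-26560.jar,
// modeled on GenericUDTFCount2 from hive/contrib.
class GenericUDTFCount3 extends GenericUDTF {
  private var count = 0
  private val forwardObj = new Array[AnyRef](1)

  // This is the override Spark's error message asks for: it declares the single
  // output column, letting HiveGenericUDTF derive the element schema.
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    ObjectInspectorFactory.getStandardStructObjectInspector(
      Arrays.asList("col1"),
      Arrays.asList[ObjectInspector](PrimitiveObjectInspectorFactory.javaIntObjectInspector))
  }

  // Count one row per input.
  override def process(args: Array[AnyRef]): Unit = count += 1

  // Emit the final count three times, matching the expected
  // Row(3) :: Row(3) :: Row(3) result in the test above.
  override def close(): Unit = {
    forwardObj(0) = Int.box(count)
    (1 to 3).foreach(_ => forward(forwardObj))
  }
}
```

Because the class lives only in the jar registered via `CREATE FUNCTION ... USING JAR`, resolving it exercises exactly the classloader path that `Utils.withContextClassLoader(clazz.getClassLoader)` fixes.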