diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 0d0269f694300..78cff99f1edb4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
 import org.apache.spark.sql.catalyst.parser.ParserInterface
-import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
+import org.apache.spark.sql.hive.HiveShim._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DecimalType, DoubleType}
 import org.apache.spark.util.Utils
@@ -66,6 +66,8 @@ private[sql] class HiveSessionCatalog(
    * Construct a [[FunctionBuilder]] based on the provided class that represents a function.
    */
   private def makeFunctionBuilder(name: String, clazz: Class[_]): FunctionBuilder = {
+    validateHiveUserDefinedFunction(clazz)
+
     // When we instantiate hive UDF wrapper class, we may throw exception if the input
     // expressions don't satisfy the hive UDF, such as type mismatch, input number
     // mismatch, etc. Here we catch the exception and throw AnalysisException instead.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
index 9e9894803ce25..9e6a1e7434e59 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
@@ -22,15 +22,16 @@ import java.rmi.server.UID
 
 import scala.collection.JavaConverters._
 import scala.language.implicitConversions
-import scala.reflect.ClassTag
+import scala.reflect.{classTag, ClassTag}
+import scala.util.control.NonFatal
 
 import com.google.common.base.Objects
 import org.apache.avro.Schema
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
-import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
+import org.apache.hadoop.hive.ql.exec.{MapredContext, UDF, Utilities}
 import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc}
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro
+import org.apache.hadoop.hive.ql.udf.generic.{GenericUDF, GenericUDFMacro, GenericUDTF}
 import org.apache.hadoop.hive.serde2.ColumnProjectionUtils
 import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils}
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector
@@ -42,7 +43,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.util.Utils
 
-private[hive] object HiveShim {
+private[hive] object HiveShim extends Logging {
   // Precision and scale to pass for unlimited decimals; these are the same as the precision and
   // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs)
   val UNLIMITED_DECIMAL_PRECISION = 38
@@ -111,6 +112,32 @@ private[hive] object HiveShim {
     }
   }
 
+  private def hasInheritanceOf[UDFType: ClassTag](func: String, clazz: Class[_]): Boolean = {
+    val parentClazz = classTag[UDFType].runtimeClass
+    if (parentClazz.isAssignableFrom(clazz)) {
+      try {
+        val funcClass = clazz.getMethod(func, classOf[MapredContext])
+        // If a given `func` is not overridden, `Method.getDeclaringClass` returns
+        // a parent Class object.
+        funcClass.getDeclaringClass != parentClazz
+      } catch {
+        case NonFatal(_) => false
+      }
+    } else {
+      false
+    }
+  }
+
+  def validateHiveUserDefinedFunction(udfClass: Class[_]): Unit = {
+    if (hasInheritanceOf[GenericUDF]("configure", udfClass) ||
+        hasInheritanceOf[GenericUDTF]("configure", udfClass)) {
+      logWarning(s"Found an overridden method `configure` in ${udfClass.getSimpleName}, but " +
+        "Spark does not call the method during initialization because Spark does not use " +
+        "MapredContext inside (see SPARK-21533). So, you might want to reconsider the " +
+        s"implementation of ${udfClass.getSimpleName}.")
+    }
+  }
+
   /**
    * This class provides the UDF creation and also the UDF instance serialization and
    * de-serialization cross process boundary.
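
For context, here is a minimal sketch (not part of the patch) of the kind of Hive UDF this change targets: a `GenericUDF` that overrides `configure(MapredContext)`. The class name `ConfiguringUDF` and its trivial body are made up for illustration; the point is only that an overridden `configure` now triggers the SPARK-21533 warning from `validateHiveUserDefinedFunction`.

```scala
import org.apache.hadoop.hive.ql.exec.MapredContext
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

// Hypothetical UDF: overriding `configure` is what hasInheritanceOf[GenericUDF]("configure", ...)
// detects, because the override's declaring class differs from GenericUDF itself.
class ConfiguringUDF extends GenericUDF {

  // Spark SQL never invokes this method, so any setup done here is silently skipped.
  override def configure(context: MapredContext): Unit = {
    // e.g. reading job configuration here has no effect when the UDF runs under Spark
  }

  override def initialize(arguments: Array[ObjectInspector]): ObjectInspector =
    PrimitiveObjectInspectorFactory.javaStringObjectInspector

  override def evaluate(arguments: Array[DeferredObject]): AnyRef = "dummy"

  override def getDisplayString(children: Array[String]): String = "configuring_udf()"
}
```

When such a class is registered as a Hive function and later resolved, the builder path in `HiveSessionCatalog.makeFunctionBuilder` now calls `validateHiveUserDefinedFunction(clazz)` first, so the user sees the warning about the unused `configure` override instead of silently losing that initialization logic.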