From 2421c72cce895c3714f5d8f15cf2675bf49c9d24 Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Mon, 16 Mar 2020 22:35:12 +0800
Subject: [PATCH 01/11] init

---
 .../spark/sql/catalyst/ScalaReflection.scala  |  9 +++
 .../sql/catalyst/analysis/Analyzer.scala      | 38 +++++++++-
 .../sql/catalyst/expressions/ScalaUDF.scala   |  1 +
 .../apache/spark/sql/UDFRegistration.scala    | 69 ++++++++++++-------
 .../sql/expressions/UserDefinedFunction.scala |  2 +
 .../org/apache/spark/sql/functions.scala      | 33 ++++++---
 .../scala/org/apache/spark/sql/UDFSuite.scala |  8 +++
 7 files changed, 125 insertions(+), 35 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 1f7634bafa420..5cda30e5f0a35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -611,6 +611,15 @@ object ScalaReflection extends ScalaReflection {
     }
   }
 
+  def getClassForCaseClass[T: TypeTag]: Option[Class[_]] = {
+    val tpe = localTypeOf[T]
+    if (isSubtype(tpe.dealias, localTypeOf[Product])) {
+      Some(getClassFromType(tpe))
+    } else {
+      None
+    }
+  }
+
   /*
    * Retrieves the runtime class corresponding to the provided type.
    */
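For illustration only (not part of the patch), the new helper resolves to Some(runtimeClass)
exactly when the type is a subtype of Product, so case classes match while primitives and
String do not. Note that tuples are Products too, so they would also be reported:

    import org.apache.spark.sql.catalyst.ScalaReflection

    case class Person(age: Int)

    ScalaReflection.getClassForCaseClass[Person]      // Some(class Person)
    ScalaReflection.getClassForCaseClass[Int]         // None
    ScalaReflection.getClassForCaseClass[String]      // None
    ScalaReflection.getClassForCaseClass[(Int, Int)]  // Some(class Tuple2): tuples are Products
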
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index edcfd6fe8ab61..eb308fe0c9650 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -256,6 +256,9 @@
       Batch("Nondeterministic", Once,
         PullOutNondeterministic),
       Batch("UDF", Once,
+        ResolveCaseClassForUDF,
+        // `ResolveCaseClassForUDF` may generate `NewInstance`, so we need to resolve it
+        ResolveNewInstance,
         HandleNullInputsForUDF),
       Batch("UpdateNullability", Once,
         UpdateAttributeNullability),
@@ -2707,7 +2710,7 @@
 
       case p => p transformExpressionsUp {
 
-        case udf @ ScalaUDF(_, _, inputs, inputPrimitives, _, _, _, _)
+        case udf @ ScalaUDF(_, _, inputs, inputPrimitives, _, _, _, _, _)
             if inputPrimitives.contains(true) =>
           // Otherwise, add special handling of null for fields that can't accept null.
           // The result of operations like this, when passed null, is generally to return null.
@@ -2741,6 +2744,39 @@
     }
   }
 
+  object ResolveCaseClassForUDF extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+      case p if !p.resolved => p // Skip unresolved nodes.
+
+      case p => p transformExpressionsUp {
+
+        case udf @ ScalaUDF(_, _, inputs, _, inputCaseClass, _, _, _, _) =>
+          if (inputCaseClass.exists(_.isDefined)) {
+            assert(inputs.size == inputCaseClass.size)
+            val newInputs = inputs.zip(inputCaseClass).map {
+              case (input, clazzOpt) =>
+                if (clazzOpt.isDefined) {
+                  val clazz = clazzOpt.get
+                  assert(input.dataType.isInstanceOf[StructType],
+                    s"expects StructType, but got ${input.dataType}")
+                  val dataType = input.dataType.asInstanceOf[StructType]
+                  val args = dataType.toAttributes.zipWithIndex.map { case (a, i) =>
+                    GetStructField(input, i, Some(a.name))
+                  }
+                  NewInstance(clazz, args, ObjectType(clazz))
+                } else {
+                  input
+                }
+            }
+            // FIXME(wuyi): why applied 2 times?
+            udf.copy(children = newInputs, inputCaseClass = Nil)
+          } else {
+            udf
+          }
+      }
+    }
+  }
+
   /**
    * Check and add proper window frames for all window functions.
    */
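For illustration (not part of the patch), the net effect of ResolveCaseClassForUDF on a UDF
over a struct column is roughly the following expression rewrite, assuming the Person case
class used in the test below:

    before: ScalaUDF(f, IntegerType, children = [age#5],
                     inputCaseClass = [Some(class Person)])
    after:  ScalaUDF(f, IntegerType,
                     children = [NewInstance(class Person, [GetStructField(age#5, 0)])],
                     inputCaseClass = [])

The struct input is unpacked field by field and fed to the case class constructor, so the
closure receives a real Person instance. Resetting inputCaseClass to Nil keeps the rewrite
from matching the same UDF again on a later pass.
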
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 10f8ec9617d1b..b399793245f9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -49,6 +49,7 @@ case class ScalaUDF(
     dataType: DataType,
     children: Seq[Expression],
     inputPrimitives: Seq[Boolean],
+    inputCaseClass: Seq[Option[Class[_]]] = Nil,
    inputTypes: Seq[AbstractDataType] = Nil,
     udfName: Option[String] = None,
     nullable: Boolean = true,

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 0f08e10c00d22..71bce227929e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -181,7 +181,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClass: Seq[Option[Class[_]]] = Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClass).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 0) {
       finalUdf.createScalaUDF(e)
@@ -201,7 +202,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 1) {
       finalUdf.createScalaUDF(e)
@@ -221,7 +223,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 2) {
       finalUdf.createScalaUDF(e)
@@ -241,7 +244,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 3) {
       finalUdf.createScalaUDF(e)
@@ -261,7 +265,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 4) {
       finalUdf.createScalaUDF(e)
@@ -281,7 +286,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 5) {
       finalUdf.createScalaUDF(e)
@@ -301,7 +307,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 6) {
       finalUdf.createScalaUDF(e)
@@ -321,7 +328,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 7) {
       finalUdf.createScalaUDF(e)
@@ -341,7 +349,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 8) {
       finalUdf.createScalaUDF(e)
@@ -361,7 +370,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 9) {
       finalUdf.createScalaUDF(e)
@@ -381,7 +391,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 10) {
       finalUdf.createScalaUDF(e)
@@ -401,7 +412,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 11) {
       finalUdf.createScalaUDF(e)
@@ -421,7 +433,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 12) {
       finalUdf.createScalaUDF(e)
@@ -441,7 +454,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 13) {
       finalUdf.createScalaUDF(e)
@@ -461,7 +475,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 14) {
       finalUdf.createScalaUDF(e)
@@ -481,7 +496,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 15) {
       finalUdf.createScalaUDF(e)
@@ -501,7 +517,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: ScalaReflection.getClassForCaseClass[A16] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 16) {
       finalUdf.createScalaUDF(e)
@@ -521,7 +538,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: ScalaReflection.getClassForCaseClass[A16] :: ScalaReflection.getClassForCaseClass[A17] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 17) {
       finalUdf.createScalaUDF(e)
@@ -541,7 +559,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: ScalaReflection.getClassForCaseClass[A16] :: ScalaReflection.getClassForCaseClass[A17] :: ScalaReflection.getClassForCaseClass[A18] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 18) {
       finalUdf.createScalaUDF(e)
@@ -561,7 +580,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: ScalaReflection.getClassForCaseClass[A16] :: ScalaReflection.getClassForCaseClass[A17] :: ScalaReflection.getClassForCaseClass[A18] :: ScalaReflection.getClassForCaseClass[A19] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 19) {
       finalUdf.createScalaUDF(e)
@@ -581,7 +601,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: ScalaReflection.getClassForCaseClass[A16] :: ScalaReflection.getClassForCaseClass[A17] :: ScalaReflection.getClassForCaseClass[A18] :: ScalaReflection.getClassForCaseClass[A19] :: ScalaReflection.getClassForCaseClass[A20] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 20) {
       finalUdf.createScalaUDF(e)
@@ -601,7 +622,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Try(ScalaReflection.schemaFor[A21]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: ScalaReflection.getClassForCaseClass[A16] :: ScalaReflection.getClassForCaseClass[A17] :: ScalaReflection.getClassForCaseClass[A18] :: ScalaReflection.getClassForCaseClass[A19] :: ScalaReflection.getClassForCaseClass[A20] :: ScalaReflection.getClassForCaseClass[A21] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 21) {
       finalUdf.createScalaUDF(e)
@@ -621,7 +643,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Try(ScalaReflection.schemaFor[A21]).toOption :: Try(ScalaReflection.schemaFor[A22]).toOption :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: ScalaReflection.getClassForCaseClass[A16] :: ScalaReflection.getClassForCaseClass[A17] :: ScalaReflection.getClassForCaseClass[A18] :: ScalaReflection.getClassForCaseClass[A19] :: ScalaReflection.getClassForCaseClass[A20] :: ScalaReflection.getClassForCaseClass[A21] :: ScalaReflection.getClassForCaseClass[A22] :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 22) {
       finalUdf.createScalaUDF(e)
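For example (a sketch, not part of the patch), the register path above should now let a SQL
query pass a struct column to a UDF that takes a case class:

    case class Person(age: Int)

    spark.udf.register("getAge", (p: Person) => p.age)
    Seq(("Jack", Person(50))).toDF("name", "age").createOrReplaceTempView("people")
    spark.sql("SELECT getAge(age) FROM people")   // expected: one row containing 50
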
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index c50168cf7ac13..8455b0d593ee5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -94,6 +94,7 @@ private[spark] case class SparkUserDefinedFunction(
     f: AnyRef,
     dataType: DataType,
     inputSchemas: Seq[Option[ScalaReflection.Schema]],
+    inputCaseClass: Seq[Option[Class[_]]] = Nil,
     name: Option[String] = None,
     nullable: Boolean = true,
     deterministic: Boolean = true) extends UserDefinedFunction {
@@ -115,6 +116,7 @@ private[spark] case class SparkUserDefinedFunction(
       dataType,
       exprs,
       inputsPrimitive,
+      inputCaseClass,
      inputTypes,
       udfName = name,
       nullable = nullable,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index e2d3d55812c51..70fd8cdeb495a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4388,7 +4388,8 @@ object functions {
   def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val inputCaseClasses: Seq[Option[Class[_]]] = Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4404,7 +4405,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4420,7 +4422,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4436,7 +4439,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4452,7 +4456,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4468,7 +4473,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4484,7 +4490,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4500,7 +4507,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4516,7 +4524,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4532,7 +4541,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4548,7 +4558,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A10])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
     if (nullable) udf else udf.asNonNullable()
   }
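A quick sketch of the DataFrame-side API with two case-class arguments (hypothetical First
and Last classes, exercising the generated two-argument overload above):

    case class First(s: String)
    case class Last(s: String)

    // Both struct columns should be rebuilt into their case classes before the closure runs.
    val fullName = udf((f: First, l: Last) => f.s + " " + l.s)
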
:: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A10])).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputSchemas) + val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: Nil + val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses) if (nullable) udf else udf.asNonNullable() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index e0857ed6bc35a..038e53e28d90a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.QueryExecutionListener +private case class Person(age: Int) private case class FunctionResult(f1: String, f2: String) @@ -551,4 +552,11 @@ class UDFSuite extends QueryTest with SharedSparkSession { } assert(e.getMessage.contains("Invalid arguments for function cast")) } + + test("SPARK-30127: Support input case class in typed Scala UDF") { + val f = (p: Person) => p.age + val myUdf = udf(f) + val df = Seq(("Jack", Person(50))).toDF("name", "age") + checkAnswer(df.select(myUdf(Column("age"))), Row(50) :: Nil) + } } From dd902989815f2b9a5fceebfbe6b78591524e4c60 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 17 Mar 2020 23:54:05 +0800 Subject: [PATCH 02/11] update --- .../spark/sql/catalyst/ScalaReflection.scala | 5 ++- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/expressions/ScalaUDF.scala | 3 ++ .../expressions/objects/objects.scala | 40 +++++++++++++++++-- .../scala/org/apache/spark/sql/UDFSuite.scala | 31 +++++++++++--- 5 files changed, 68 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 5cda30e5f0a35..7c87eea967acc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -581,11 +581,12 @@ object ScalaReflection extends ScalaReflection { * Note that it only works for scala classes with primary constructor, and currently doesn't * support inner class. 
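Review note: for readers skimming the series, the end state patch 01 is driving at can be reproduced as a small stand-alone program. Everything below except the udf((p: Person) => p.age) shape (taken from the SPARK-30127 test above) is illustrative scaffolding, not part of the patch:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

case class Person(age: Int)

object CaseClassUdfDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("case-class-udf").getOrCreate()
    import spark.implicits._
    // With this series applied, the struct column "age" is decoded back into a
    // Person before the function body runs.
    val getAge = udf((p: Person) => p.age)
    val df = Seq(("Jack", Person(50))).toDF("name", "age")
    df.select(getAge(col("age"))).show() // expect a single row: 50
    spark.stop()
  }
}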
*/ - def getConstructorParameters(cls: Class[_]): Seq[(String, Type)] = { + // FIXME(wuyi): test on inner class/repl + def getConstructorParameters(cls: Class[_]): Seq[Class[_]] = { val m = runtimeMirror(cls.getClassLoader) val classSymbol = m.staticClass(cls.getName) val t = classSymbol.selfType - getConstructorParameters(t) + getConstructorParameters(t).map { case (_, tpe) => getClassFromType(tpe)} } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index eb308fe0c9650..051baea1ba0a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2768,7 +2768,7 @@ class Analyzer( input } } - // FIXME(wuyi): why applied 2 times? + // assign Nil to inputCaseClass to avoid applying this rule multiple times udf.copy(children = newInputs, inputCaseClass = Nil) } else { udf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index b399793245f9f..f19b7278a888a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -35,6 +35,9 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType} * UDF return null if there is any null input value of these types. On the * other hand, Java UDFs can only have boxed types, thus this parameter will * always be all false. + * @param inputCaseClass The Class[_] for each input parameter that is a case class. + * If an input parameter is not a case class, the corresponding value + * is None. * @param inputTypes The expected input types of this UDF, used to perform type coercion. If we do * not want to perform coercion, simply use "Nil".
Note that it would've been * better to use Option of Seq[DataType] so we can use "None" as the case for no diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 54abd09d89ddb..4eaa16231301e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** @@ -448,8 +449,20 @@ case class NewInstance( childrenResolved && !needOuterPointer } + private def argConverters(): Seq[Any => Any] = { + val inputTypes = ScalaReflection.expressionJavaClasses(arguments) + val neededTypes = ScalaReflection.getConstructorParameters(cls) + arguments.zip(inputTypes).zip(neededTypes).map { case ((arg, input), needed) => + if (needed.isAssignableFrom(input)) { + identity[Any] _ + } else { + CatalystTypeConverters.createToScalaConverter(arg.dataType) + } + } + } + @transient private lazy val constructor: (Seq[AnyRef]) => Any = { - val paramTypes = ScalaReflection.expressionJavaClasses(arguments) + val paramTypes = ScalaReflection.getConstructorParameters(cls) val getConstructor = (paramClazz: Seq[Class[_]]) => { ScalaReflection.findConstructor(cls, paramClazz).getOrElse { sys.error(s"Couldn't find a valid constructor on $cls") @@ -472,6 +485,10 @@ case class NewInstance( override def eval(input: InternalRow): Any = { val argValues = arguments.map(_.eval(input)) + .zip(argConverters()) + .map { case (arg, converter) => + converter(arg) + } constructor(argValues.map(_.asInstanceOf[AnyRef])) } @@ -480,6 +497,20 @@ case class NewInstance( val (argCode, argString, resultIsNull) = prepareArguments(ctx) + val converterClassName = classOf[Any => Any].getName + val convertersTerm = ctx.addReferenceObj( + "converters", argConverters().toArray, s"$converterClassName[]") + val argTypes = ScalaReflection.getConstructorParameters(cls) + val convertedArgs = argTypes.map { a => + ctx.addMutableState(CodeGenerator.boxedType(a.getSimpleName), "convertedArg") + } + val convertedCode = argString.split(",").zip(argTypes).zipWithIndex.map { + case ((arg, tpe), i) => + s"${convertedArgs(i)} = " + + s"(${CodeGenerator.boxedType(tpe.getSimpleName)}) $convertersTerm[$i].apply($arg);" + }.mkString("\n") + val convertedArgString = convertedArgs.mkString(",") + val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx)) ev.isNull = resultIsNull @@ -488,16 +519,17 @@ case class NewInstance( // If there are no constructors, the `new` method will fail. In // this case we can try to call the apply method constructor // that might be defined on the companion object. 
- case 0 => s"$className$$.MODULE$$.apply($argString)" + case 0 => s"$className$$.MODULE$$.apply($convertedArgString)" case _ => outer.map { gen => - s"${gen.value}.new ${cls.getSimpleName}($argString)" + s"${gen.value}.new ${cls.getSimpleName}($convertedArgString)" }.getOrElse { - s"new $className($argString)" + s"new $className($convertedArgString)" } } val code = code""" $argCode + $convertedCode ${outer.map(_.code).getOrElse("")} final $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $constructorCall; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 038e53e28d90a..5f33614f50436 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -33,8 +33,6 @@ import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.QueryExecutionListener -private case class Person(age: Int) - private case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest with SharedSparkSession { @@ -553,10 +551,31 @@ class UDFSuite extends QueryTest with SharedSparkSession { assert(e.getMessage.contains("Invalid arguments for function cast")) } - test("SPARK-30127: Support input case class in typed Scala UDF") { - val f = (p: Person) => p.age + test("only one case class parameter") { + val f = (d: TestData) => d.key * d.value.toInt + val myUdf = udf(f) + val df = Seq(("data", TestData(50, "2"))).toDF("col1", "col2") + checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil) + } + + test("one case class with primitive parameter") { + val f = (i: Int, p: TestData) => p.key * i + val myUdf = udf(f) + val df = Seq((2, TestData(50, "data"))).toDF("col1", "col2") + checkAnswer(df.select(myUdf(Column("col1"), Column("col2"))), Row(100) :: Nil) + } + + test("multiple case class parameters") { + val f = (d1: TestData, d2: TestData) => d1.key * d2.key + val myUdf = udf(f) + val df = Seq((TestData(10, "d1"), TestData(50, "d2"))).toDF("col1", "col2") + checkAnswer(df.select(myUdf(Column("col1"), Column("col2"))), Row(500) :: Nil) + } + + test("input case class parameter and return case class") { + val f = (d1: TestData) => TestData(d1.key * 2, "copy") val myUdf = udf(f) - val df = Seq(("Jack", Person(50))).toDF("name", "age") - checkAnswer(df.select(myUdf(Column("age"))), Row(50) :: Nil) + val df = Seq(("data", TestData(50, "d2"))).toDF("col1", "col2") + checkAnswer(df.select(myUdf(Column("col2"))), Row(Row(100, "copy")) :: Nil) } } From 2b186bdd46ad229dd337a0405595d09446884145 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 17 Mar 2020 23:58:46 +0800 Subject: [PATCH 03/11] support inner class --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 2 +- .../apache/spark/sql/catalyst/expressions/objects/objects.scala | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 7c87eea967acc..6ff07a570023b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -581,7 +581,7 @@ object ScalaReflection extends ScalaReflection { * Note that it only works for scala classes with primary constructor, and currently doesn't * support inner
class. */ - // FIXME(wuyi): test on inner class/repl + // FIXME(wuyi): support inner class def getConstructorParameters(cls: Class[_]): Seq[Class[_]] = { val m = runtimeMirror(cls.getClassLoader) val classSymbol = m.staticClass(cls.getName) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 4eaa16231301e..03942623eb696 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -35,7 +35,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** From 84e7a7f2c4e542b6f8691e5e402b561fb2ce1acf Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Wed, 18 Mar 2020 09:53:59 +0800 Subject: [PATCH 04/11] fix non-static inner class --- .../apache/spark/sql/catalyst/ScalaReflection.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 6ff07a570023b..8794833c14352 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst +import java.lang.reflect.Modifier + import org.apache.commons.lang3.reflect.ConstructorUtils import org.apache.spark.internal.Logging @@ -578,15 +580,18 @@ object ScalaReflection extends ScalaReflection { /** * Returns the parameter names and types for the primary constructor of this class. * - * Note that it only works for scala classes with primary constructor, and currently doesn't - * support inner class. + * Note that it only works for scala classes with primary constructor. 
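Review note: the per-argument converters that patches 02 and 03 wire into NewInstance exist because Catalyst hands struct fields over in its internal representation. A minimal sketch of the gap, runnable outside the planner (CatalystTypeConverters is internal API, so in some builds this may need to live under an org.apache.spark.sql package):

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

object ConverterGapDemo {
  def main(args: Array[String]): Unit = {
    // What GetStructField yields for a String field of a case class:
    val internal: Any = UTF8String.fromString("2")
    // The assignability test in argConverters would fail here...
    println(classOf[String].isAssignableFrom(internal.getClass)) // false
    // ...so it falls back to a Catalyst-to-Scala converter:
    val toScala = CatalystTypeConverters.createToScalaConverter(StringType)
    println(toScala(internal)) // "2" as a java.lang.String
  }
}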
*/ - // FIXME(wuyi): support inner class def getConstructorParameters(cls: Class[_]): Seq[Class[_]] = { val m = runtimeMirror(cls.getClassLoader) val classSymbol = m.staticClass(cls.getName) val t = classSymbol.selfType - getConstructorParameters(t).map { case (_, tpe) => getClassFromType(tpe)} + val dropHead = if (cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)) { + 1 + } else { + 0 + } + getConstructorParameters(t).drop(dropHead).map { case (_, tpe) => getClassFromType(tpe) } } /** From d600cf79edf5329fc8187143bf63d834d1598809 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Wed, 18 Mar 2020 17:16:55 +0800 Subject: [PATCH 05/11] use expressionencoder --- .../spark/sql/catalyst/ScalaReflection.scala | 23 +- .../sql/catalyst/analysis/Analyzer.scala | 36 -- .../sql/catalyst/expressions/ScalaUDF.scala | 531 +++++++++--------- .../expressions/objects/objects.scala | 39 +- .../apache/spark/sql/UDFRegistration.scala | 92 +-- .../sql/expressions/UserDefinedFunction.scala | 4 +- .../org/apache/spark/sql/functions.scala | 44 +- 7 files changed, 348 insertions(+), 421 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8794833c14352..1f7634bafa420 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst -import java.lang.reflect.Modifier - import org.apache.commons.lang3.reflect.ConstructorUtils import org.apache.spark.internal.Logging @@ -580,18 +578,14 @@ object ScalaReflection extends ScalaReflection { /** * Returns the parameter names and types for the primary constructor of this class. * - * Note that it only works for scala classes with primary constructor. + * Note that it only works for scala classes with primary constructor, and currently doesn't + * support inner class. */ - def getConstructorParameters(cls: Class[_]): Seq[Class[_]] = { + def getConstructorParameters(cls: Class[_]): Seq[(String, Type)] = { val m = runtimeMirror(cls.getClassLoader) val classSymbol = m.staticClass(cls.getName) val t = classSymbol.selfType - val dropHead = if (cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)) { - 1 - } else { - 0 - } - getConstructorParameters(t).drop(dropHead).map { case (_, tpe) => getClassFromType(tpe) } + getConstructorParameters(t) } /** @@ -617,15 +611,6 @@ object ScalaReflection extends ScalaReflection { } } - def getClassForCaseClass[T: TypeTag]: Option[Class[_]] = { - val tpe = localTypeOf[T] - if (isSubtype(tpe.dealias, localTypeOf[Product])) { - Some(getClassFromType(tpe)) - } else { - None - } - } - /* * Retrieves the runtime class corresponding to the provided type. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 051baea1ba0a1..fb81923291f7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -256,9 +256,6 @@ class Analyzer( Batch("Nondeterministic", Once, PullOutNondeterministic), Batch("UDF", Once, - ResolveCaseClassForUDF, - // `ResolveCaseClassForUDF` may generates `NewInstance` so we need to resolve it - ResolveNewInstance, HandleNullInputsForUDF), Batch("UpdateNullability", Once, UpdateAttributeNullability), @@ -2744,39 +2741,6 @@ class Analyzer( } } - object ResolveCaseClassForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { - case p if !p.resolved => p // Skip unresolved nodes. - - case p => p transformExpressionsUp { - - case udf @ ScalaUDF(_, _, inputs, _, inputCaseClass, _, _, _, _) => - if (inputCaseClass.exists(_.isDefined)) { - assert(inputs.size == inputCaseClass.size) - val newInputs = inputs.zip(inputCaseClass).map { - case (input, clazzOpt) => - if (clazzOpt.isDefined) { - val clazz = clazzOpt.get - assert(input.dataType.isInstanceOf[StructType], - s"expects StructType, but got ${input.dataType}") - val dataType = input.dataType.asInstanceOf[StructType] - val args = dataType.toAttributes.zipWithIndex.map { case (a, i) => - GetStructField(input, i, Some(a.name)) - } - NewInstance(clazz, args, ObjectType(clazz)) - } else { - input - } - } - // assign Nil to inputCaseClass to avoid applying this rule multiple times - udf.copy(children = newInputs, inputCaseClass = Nil) - } else { - udf - } - } - } - } - /** * Check and add proper window frames for all window functions. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index f19b7278a888a..a20e9c6fbf57b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{AbstractDataType, DataType} @@ -35,9 +35,9 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType} * UDF return null if there is any null input value of these types. On the * other hand, Java UDFs can only have boxed types, thus this parameter will * always be all false. - * @param inputCaseClass The Class[_] for each input parameter that is a case class. - * If an input parameter is not a case class, the corresponding value - * is None. + * @param inputEncoders ExpressionEncoder for each input parameter. An input parameter that is + * serialized as a struct uses its encoder, instead of CatalystTypeConverters, to + * convert the internal value to a Scala value.
* @param inputTypes The expected input types of this UDF, used to perform type coercion. If we do * not want to perform coercion, simply use "Nil". Note that it would've been * better to use Option of Seq[DataType] so we can use "None" as the case for no @@ -52,7 +53,7 @@ case class ScalaUDF( dataType: DataType, children: Seq[Expression], inputPrimitives: Seq[Boolean], - inputCaseClass: Seq[Option[Class[_]]] = Nil, + inputEncoders: Seq[ExpressionEncoder[_]] = Nil, inputTypes: Seq[AbstractDataType] = Nil, udfName: Option[String] = None, nullable: Boolean = true, @@ -63,6 +64,14 @@ case class ScalaUDF( override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})" + private def createToScalaConverter(i: Int, dataType: DataType): Any => Any = { + val encoder = inputEncoders(i) + encoder.isSerializedAsStructForTopLevel match { + case true => r: Any => encoder.resolveAndBind().fromRow(r.asInstanceOf[InternalRow]) + case false => CatalystTypeConverters.createToScalaConverter(dataType) + } + } + // scalastyle:off line.size.limit /** This method has been generated by this script @@ -70,7 +79,7 @@ case class ScalaUDF( (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) - val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _) + val converters = (0 to x - 1).map(x => s"lazy val converter$x = createToScalaConverter($x, child$x.dataType)").reduce(_ + "\n " + _) val evals = (0 to x - 1).map(x => s"converter$x(child$x.eval(input))").reduce(_ + ",\n " + _) s"""case $x => @@ -95,7 +104,7 @@ case class ScalaUDF( case 1 => val func = function.asInstanceOf[(Any) => Any] val child0 = children(0) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) (input: InternalRow) => { func( converter0(child0.eval(input))) @@ -105,8 +114,8 @@ case class ScalaUDF( val func = function.asInstanceOf[(Any, Any) => Any] val child0 = children(0) val child1 = children(1) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -118,9 +127,9 @@ case class ScalaUDF( val child0 = children(0) val child1 = children(1) val child2 = children(2) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -134,10 +143,10 @@ case class ScalaUDF( val child1 = children(1) val child2 = children(2) val child3 = children(3) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = 
CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -153,11 +162,11 @@ case class ScalaUDF( val child2 = children(2) val child3 = children(3) val child4 = children(4) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -175,12 +184,12 @@ case class ScalaUDF( val child3 = children(3) val child4 = children(4) val child5 = children(5) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -200,13 +209,13 @@ case class ScalaUDF( val child4 = children(4) val child5 = children(5) val child6 = children(6) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val 
converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -228,14 +237,14 @@ case class ScalaUDF( val child5 = children(5) val child6 = children(6) val child7 = children(7) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) + lazy val converter7 = createToScalaConverter(7, child7.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -259,15 +268,15 @@ case class ScalaUDF( val child6 = children(6) val child7 = children(7) val child8 = children(8) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) - lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) + lazy val converter7 = createToScalaConverter(7, child7.dataType) + lazy val converter8 = createToScalaConverter(8, child8.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -293,16 +302,16 @@ case class ScalaUDF( val child7 = children(7) val child8 = children(8) val child9 = children(9) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = 
CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) - lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) - lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) + lazy val converter7 = createToScalaConverter(7, child7.dataType) + lazy val converter8 = createToScalaConverter(8, child8.dataType) + lazy val converter9 = createToScalaConverter(9, child9.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -330,17 +339,17 @@ case class ScalaUDF( val child8 = children(8) val child9 = children(9) val child10 = children(10) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) - lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) - lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) - lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) + lazy val converter7 = createToScalaConverter(7, child7.dataType) + lazy val converter8 = createToScalaConverter(8, child8.dataType) + lazy val converter9 = createToScalaConverter(9, child9.dataType) + lazy val converter10 = createToScalaConverter(10, child10.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -370,18 +379,18 @@ case class ScalaUDF( val child9 = children(9) val child10 = children(10) val child11 = children(11) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = 
CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) - lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) - lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) - lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) - lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) + lazy val converter7 = createToScalaConverter(7, child7.dataType) + lazy val converter8 = createToScalaConverter(8, child8.dataType) + lazy val converter9 = createToScalaConverter(9, child9.dataType) + lazy val converter10 = createToScalaConverter(10, child10.dataType) + lazy val converter11 = createToScalaConverter(11, child11.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -413,19 +422,19 @@ case class ScalaUDF( val child10 = children(10) val child11 = children(11) val child12 = children(12) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) - lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) - lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) - lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) - lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) - lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) + lazy val 
converter7 = createToScalaConverter(7, child7.dataType) + lazy val converter8 = createToScalaConverter(8, child8.dataType) + lazy val converter9 = createToScalaConverter(9, child9.dataType) + lazy val converter10 = createToScalaConverter(10, child10.dataType) + lazy val converter11 = createToScalaConverter(11, child11.dataType) + lazy val converter12 = createToScalaConverter(12, child12.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -459,20 +468,20 @@ case class ScalaUDF( val child11 = children(11) val child12 = children(12) val child13 = children(13) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) - lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) - lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) - lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) - lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) - lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) - lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) + lazy val converter7 = createToScalaConverter(7, child7.dataType) + lazy val converter8 = createToScalaConverter(8, child8.dataType) + lazy val converter9 = createToScalaConverter(9, child9.dataType) + lazy val converter10 = createToScalaConverter(10, child10.dataType) + lazy val converter11 = createToScalaConverter(11, child11.dataType) + lazy val converter12 = createToScalaConverter(12, child12.dataType) + lazy val converter13 = createToScalaConverter(13, child13.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -508,21 +517,21 @@ case class ScalaUDF( val child12 = children(12) val child13 = children(13) val child14 = children(14) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = 
CatalystTypeConverters.createToScalaConverter(child6.dataType) - lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) - lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) - lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) - lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) - lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) - lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) - lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) - lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) + lazy val converter7 = createToScalaConverter(7, child7.dataType) + lazy val converter8 = createToScalaConverter(8, child8.dataType) + lazy val converter9 = createToScalaConverter(9, child9.dataType) + lazy val converter10 = createToScalaConverter(10, child10.dataType) + lazy val converter11 = createToScalaConverter(11, child11.dataType) + lazy val converter12 = createToScalaConverter(12, child12.dataType) + lazy val converter13 = createToScalaConverter(13, child13.dataType) + lazy val converter14 = createToScalaConverter(14, child14.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -560,22 +569,22 @@ case class ScalaUDF( val child13 = children(13) val child14 = children(14) val child15 = children(15) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) - lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) - lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) - lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) - lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) - lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) - lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) - lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) - lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = 
createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) + lazy val converter7 = createToScalaConverter(7, child7.dataType) + lazy val converter8 = createToScalaConverter(8, child8.dataType) + lazy val converter9 = createToScalaConverter(9, child9.dataType) + lazy val converter10 = createToScalaConverter(10, child10.dataType) + lazy val converter11 = createToScalaConverter(11, child11.dataType) + lazy val converter12 = createToScalaConverter(12, child12.dataType) + lazy val converter13 = createToScalaConverter(13, child13.dataType) + lazy val converter14 = createToScalaConverter(14, child14.dataType) + lazy val converter15 = createToScalaConverter(15, child15.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -615,23 +624,23 @@ case class ScalaUDF( val child14 = children(14) val child15 = children(15) val child16 = children(16) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) - lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) - lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) - lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) - lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) - lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) - lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) - lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) - lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) - lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) + lazy val converter7 = createToScalaConverter(7, child7.dataType) + lazy val converter8 = createToScalaConverter(8, child8.dataType) + lazy val converter9 = createToScalaConverter(9, child9.dataType) + lazy val converter10 = createToScalaConverter(10, child10.dataType) + lazy val converter11 = createToScalaConverter(11, child11.dataType) + lazy val converter12 = createToScalaConverter(12, child12.dataType) + lazy val converter13 = createToScalaConverter(13, 
child13.dataType) + lazy val converter14 = createToScalaConverter(14, child14.dataType) + lazy val converter15 = createToScalaConverter(15, child15.dataType) + lazy val converter16 = createToScalaConverter(16, child16.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -673,24 +682,24 @@ case class ScalaUDF( val child15 = children(15) val child16 = children(16) val child17 = children(17) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) - lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) - lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) - lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) - lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) - lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) - lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) - lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) - lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) - lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) - lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) + lazy val converter7 = createToScalaConverter(7, child7.dataType) + lazy val converter8 = createToScalaConverter(8, child8.dataType) + lazy val converter9 = createToScalaConverter(9, child9.dataType) + lazy val converter10 = createToScalaConverter(10, child10.dataType) + lazy val converter11 = createToScalaConverter(11, child11.dataType) + lazy val converter12 = createToScalaConverter(12, child12.dataType) + lazy val converter13 = createToScalaConverter(13, child13.dataType) + lazy val converter14 = createToScalaConverter(14, child14.dataType) + lazy val converter15 = createToScalaConverter(15, child15.dataType) + lazy val converter16 = createToScalaConverter(16, child16.dataType) + lazy val converter17 = createToScalaConverter(17, child17.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -734,25 +743,25 @@ case class ScalaUDF( val child16 = children(16) val child17 = children(17) val child18 = children(18) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = 
CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) - lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) - lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) - lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) - lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) - lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) - lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) - lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) - lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) - lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) - lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) - lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) + lazy val converter7 = createToScalaConverter(7, child7.dataType) + lazy val converter8 = createToScalaConverter(8, child8.dataType) + lazy val converter9 = createToScalaConverter(9, child9.dataType) + lazy val converter10 = createToScalaConverter(10, child10.dataType) + lazy val converter11 = createToScalaConverter(11, child11.dataType) + lazy val converter12 = createToScalaConverter(12, child12.dataType) + lazy val converter13 = createToScalaConverter(13, child13.dataType) + lazy val converter14 = createToScalaConverter(14, child14.dataType) + lazy val converter15 = createToScalaConverter(15, child15.dataType) + lazy val converter16 = createToScalaConverter(16, child16.dataType) + lazy val converter17 = createToScalaConverter(17, child17.dataType) + lazy val converter18 = createToScalaConverter(18, child18.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -798,26 +807,26 @@ case class ScalaUDF( val child17 = children(17) val child18 = children(18) val child19 = children(19) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = 
CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) - lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) - lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) - lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) - lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) - lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) - lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) - lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) - lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) - lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) - lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) - lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) - lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) + lazy val converter7 = createToScalaConverter(7, child7.dataType) + lazy val converter8 = createToScalaConverter(8, child8.dataType) + lazy val converter9 = createToScalaConverter(9, child9.dataType) + lazy val converter10 = createToScalaConverter(10, child10.dataType) + lazy val converter11 = createToScalaConverter(11, child11.dataType) + lazy val converter12 = createToScalaConverter(12, child12.dataType) + lazy val converter13 = createToScalaConverter(13, child13.dataType) + lazy val converter14 = createToScalaConverter(14, child14.dataType) + lazy val converter15 = createToScalaConverter(15, child15.dataType) + lazy val converter16 = createToScalaConverter(16, child16.dataType) + lazy val converter17 = createToScalaConverter(17, child17.dataType) + lazy val converter18 = createToScalaConverter(18, child18.dataType) + lazy val converter19 = createToScalaConverter(19, child19.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -865,27 +874,27 @@ case class ScalaUDF( val child18 = children(18) val child19 = children(19) val child20 = children(20) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) 
- lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) - lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) - lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) - lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) - lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) - lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) - lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) - lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) - lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) - lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) - lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) - lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) - lazy val converter20 = CatalystTypeConverters.createToScalaConverter(child20.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) + lazy val converter7 = createToScalaConverter(7, child7.dataType) + lazy val converter8 = createToScalaConverter(8, child8.dataType) + lazy val converter9 = createToScalaConverter(9, child9.dataType) + lazy val converter10 = createToScalaConverter(10, child10.dataType) + lazy val converter11 = createToScalaConverter(11, child11.dataType) + lazy val converter12 = createToScalaConverter(12, child12.dataType) + lazy val converter13 = createToScalaConverter(13, child13.dataType) + lazy val converter14 = createToScalaConverter(14, child14.dataType) + lazy val converter15 = createToScalaConverter(15, child15.dataType) + lazy val converter16 = createToScalaConverter(16, child16.dataType) + lazy val converter17 = createToScalaConverter(17, child17.dataType) + lazy val converter18 = createToScalaConverter(18, child18.dataType) + lazy val converter19 = createToScalaConverter(19, child19.dataType) + lazy val converter20 = createToScalaConverter(20, child20.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -935,28 +944,28 @@ case class ScalaUDF( val child19 = children(19) val child20 = children(20) val child21 = children(21) - lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) - lazy val converter8 = 
CatalystTypeConverters.createToScalaConverter(child8.dataType) - lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) - lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) - lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) - lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) - lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) - lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) - lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) - lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) - lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) - lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) - lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) - lazy val converter20 = CatalystTypeConverters.createToScalaConverter(child20.dataType) - lazy val converter21 = CatalystTypeConverters.createToScalaConverter(child21.dataType) + lazy val converter0 = createToScalaConverter(0, child0.dataType) + lazy val converter1 = createToScalaConverter(1, child1.dataType) + lazy val converter2 = createToScalaConverter(2, child2.dataType) + lazy val converter3 = createToScalaConverter(3, child3.dataType) + lazy val converter4 = createToScalaConverter(4, child4.dataType) + lazy val converter5 = createToScalaConverter(5, child5.dataType) + lazy val converter6 = createToScalaConverter(6, child6.dataType) + lazy val converter7 = createToScalaConverter(7, child7.dataType) + lazy val converter8 = createToScalaConverter(8, child8.dataType) + lazy val converter9 = createToScalaConverter(9, child9.dataType) + lazy val converter10 = createToScalaConverter(10, child10.dataType) + lazy val converter11 = createToScalaConverter(11, child11.dataType) + lazy val converter12 = createToScalaConverter(12, child12.dataType) + lazy val converter13 = createToScalaConverter(13, child13.dataType) + lazy val converter14 = createToScalaConverter(14, child14.dataType) + lazy val converter15 = createToScalaConverter(15, child15.dataType) + lazy val converter16 = createToScalaConverter(16, child16.dataType) + lazy val converter17 = createToScalaConverter(17, child17.dataType) + lazy val converter18 = createToScalaConverter(18, child18.dataType) + lazy val converter19 = createToScalaConverter(19, child19.dataType) + lazy val converter20 = createToScalaConverter(20, child20.dataType) + lazy val converter21 = createToScalaConverter(21, child21.dataType) (input: InternalRow) => { func( converter0(child0.eval(input)), @@ -991,8 +1000,8 @@ case class ScalaUDF( val converterClassName = classOf[Any => Any].getName // The type converters for inputs and the result. 
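
All of the renumbered converters above route through a new index-aware helper instead of calling CatalystTypeConverters.createToScalaConverter directly, and the codegen converters array built just below gets the same per-index treatment. The helper's body is not shown in these hunks, so the following is only a standalone sketch of the dispatch it presumably performs: use a position-specific decoder when one is available for input i (e.g. a case-class input), otherwise fall back to the generic Catalyst-to-Scala conversion. The decoders Map and the Person type are stand-ins invented for illustration, not Spark internals.

    // Sketch only: models an index-aware createToScalaConverter(i, dataType).
    case class Person(name: String, age: Int)

    object IndexAwareConverters {
      type Converter = Any => Any

      // Stand-in for CatalystTypeConverters.createToScalaConverter(dataType).
      private val genericConverter: Converter = identity

      // Hypothetical per-position decoders, e.g. "input 0 is a Person".
      private val decoders: Map[Int, Converter] = Map(
        0 -> { case (name: String, age: Int) => Person(name, age) }
      )

      // Dispatch: prefer the position-specific decoder, else the generic path.
      def createToScalaConverter(i: Int): Converter =
        decoders.getOrElse(i, genericConverter)

      def main(args: Array[String]): Unit = {
        println(createToScalaConverter(0)(("Ann", 41))) // Person(Ann,41)
        println(createToScalaConverter(1)("plain"))     // plain
      }
    }

Keeping the converters lazy per child, as the patch does, means a decoder for an unused position is never materialized.
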
- val converters: Array[Any => Any] = children.map { c => - CatalystTypeConverters.createToScalaConverter(c.dataType) + val converters: Array[Any => Any] = children.zipWithIndex.map { case (c, i) => + createToScalaConverter(i, c.dataType) }.toArray :+ CatalystTypeConverters.createToCatalystConverter(dataType) val convertersTerm = ctx.addReferenceObj("converters", converters, s"$converterClassName[]") val errorMsgTerm = ctx.addReferenceObj("errMsg", udfErrorMessage) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 03942623eb696..54abd09d89ddb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -448,20 +448,8 @@ case class NewInstance( childrenResolved && !needOuterPointer } - private def argConverters(): Seq[Any => Any] = { - val inputTypes = ScalaReflection.expressionJavaClasses(arguments) - val neededTypes = ScalaReflection.getConstructorParameters(cls) - arguments.zip(inputTypes).zip(neededTypes).map { case ((arg, input), needed) => - if (needed.isAssignableFrom(input)) { - identity[Any] _ - } else { - CatalystTypeConverters.createToScalaConverter(arg.dataType) - } - } - } - @transient private lazy val constructor: (Seq[AnyRef]) => Any = { - val paramTypes = ScalaReflection.getConstructorParameters(cls) + val paramTypes = ScalaReflection.expressionJavaClasses(arguments) val getConstructor = (paramClazz: Seq[Class[_]]) => { ScalaReflection.findConstructor(cls, paramClazz).getOrElse { sys.error(s"Couldn't find a valid constructor on $cls") @@ -484,10 +472,6 @@ case class NewInstance( override def eval(input: InternalRow): Any = { val argValues = arguments.map(_.eval(input)) - .zip(argConverters()) - .map { case (arg, converter) => - converter(arg) - } constructor(argValues.map(_.asInstanceOf[AnyRef])) } @@ -496,20 +480,6 @@ case class NewInstance( val (argCode, argString, resultIsNull) = prepareArguments(ctx) - val converterClassName = classOf[Any => Any].getName - val convertersTerm = ctx.addReferenceObj( - "converters", argConverters().toArray, s"$converterClassName[]") - val argTypes = ScalaReflection.getConstructorParameters(cls) - val convertedArgs = argTypes.map { a => - ctx.addMutableState(CodeGenerator.boxedType(a.getSimpleName), "convertedArg") - } - val convertedCode = argString.split(",").zip(argTypes).zipWithIndex.map { - case ((arg, tpe), i) => - s"${convertedArgs(i)} = " + - s"(${CodeGenerator.boxedType(tpe.getSimpleName)}) $convertersTerm[$i].apply($arg);" - }.mkString("\n") - val convertedArgString = convertedArgs.mkString(",") - val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx)) ev.isNull = resultIsNull @@ -518,17 +488,16 @@ case class NewInstance( // If there are no constructors, the `new` method will fail. In // this case we can try to call the apply method constructor // that might be defined on the companion object. 
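
The NewInstance hunks here do two related things: constructor resolution now keys off the runtime Java classes of the argument expressions (ScalaReflection.expressionJavaClasses(arguments)) rather than the class's declared constructor parameters, and the per-argument argConverters step disappears from both the interpreted eval path and the generated code. The premise is that by the time NewInstance runs, its children already evaluate to objects of the right JVM types, so converting them again is redundant; matching the constructor against the arguments' actual classes is what makes dropping the converters safe, because a mismatch now fails loudly at lookup instead of being papered over. A standalone sketch of that lookup strategy (findConstructor below is a simplified stand-in written for this note, not Spark's ScalaReflection.findConstructor):

    // Sketch: resolve a constructor by the runtime classes of the arguments.
    case class Point(x: Integer, y: Integer)

    object ConstructorLookup {
      def findConstructor(
          cls: Class[_],
          argClasses: Seq[Class[_]]): Option[Seq[AnyRef] => Any] =
        cls.getConstructors.find { c =>
          val declared = c.getParameterTypes
          declared.length == argClasses.length &&
            declared.zip(argClasses).forall { case (d, a) => d.isAssignableFrom(a) }
        }.map(c => (args: Seq[AnyRef]) => c.newInstance(args: _*))

      def main(args: Array[String]): Unit = {
        val argValues: Seq[AnyRef] = Seq(Int.box(1), Int.box(2))
        val ctor = findConstructor(classOf[Point], argValues.map(_.getClass))
          .getOrElse(sys.error(s"Couldn't find a valid constructor on ${classOf[Point]}"))
        println(ctor(argValues)) // Point(1,2)
      }
    }
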
- case 0 => s"$className$$.MODULE$$.apply($convertedArgString)" + case 0 => s"$className$$.MODULE$$.apply($argString)" case _ => outer.map { gen => - s"${gen.value}.new ${cls.getSimpleName}($convertedArgString)" + s"${gen.value}.new ${cls.getSimpleName}($argString)" }.getOrElse { - s"new $className($convertedArgString)" + s"new $className($argString)" } } val code = code""" $argCode - $convertedCode ${outer.map(_.code).getOrElse("")} final $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $constructorCall; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 71bce227929e6..ebbee7296e008 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -181,8 +181,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Nil - val inputCaseClass: Seq[Option[Class[_]]] = Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClass).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 0) { finalUdf.createScalaUDF(e) @@ -202,8 +202,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 1) { finalUdf.createScalaUDF(e) @@ -223,8 +223,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 2) { 
finalUdf.createScalaUDF(e) @@ -244,8 +244,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 3) { finalUdf.createScalaUDF(e) @@ -265,8 +265,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 4) { finalUdf.createScalaUDF(e) @@ -286,8 +286,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: 
Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 5) { finalUdf.createScalaUDF(e) @@ -307,8 +307,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 6) { finalUdf.createScalaUDF(e) @@ -328,8 +328,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) 
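
From here on the hunks are mechanical: every register arity in this file, and the functions.udf variants later in the patch, swap the Seq[Option[Class[_]]] of case-class markers for a Seq[ExpressionEncoder[_]], one encoder per parameter, and pass it to SparkUserDefinedFunction. An encoder is strictly more information than a runtime class: it carries the serializer and deserializer expressions for the type, which is what analyzer-side resolution needs to rebuild a case-class argument from a struct input without guessing at constructors. One behavioral difference worth flagging: the schemas remain wrapped in Try(...).toOption, but ExpressionEncoder[A]() is constructed eagerly, so a parameter type with no encoder would now fail at registration rather than degrade silently; whether that is intended is not visible in this patch. As an end-to-end illustration of what the change is meant to enable (names and data invented here; this assumes the analyzer-side resolution added elsewhere in this patch series performs the struct-to-case-class rewrite):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{col, struct}

    case class Person(name: String, age: Int)

    object CaseClassUdfExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[1]").appName("udf-demo").getOrCreate()
        import spark.implicits._

        // A struct column whose shape matches Person(name, age).
        val df = Seq(Person("Ann", 41), Person("Bob", 32)).toDF()
          .select(struct(col("name"), col("age")).as("p"))

        // With encoders threaded through registration, the UDF can take the
        // case class directly instead of a Row.
        spark.udf.register("greet", (p: Person) => s"Hello, ${p.name} (${p.age})")
        df.selectExpr("greet(p)").show(false)

        spark.stop()
      }
    }
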
val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 7) { finalUdf.createScalaUDF(e) @@ -349,8 +349,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 8) { finalUdf.createScalaUDF(e) @@ -370,8 +370,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: 
ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 9) { finalUdf.createScalaUDF(e) @@ -391,8 +391,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 10) { finalUdf.createScalaUDF(e) @@ -412,8 +412,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: 
Try(ScalaReflection.schemaFor[A11]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 11) { finalUdf.createScalaUDF(e) @@ -433,8 +433,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val 
finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 12) { finalUdf.createScalaUDF(e) @@ -454,8 +454,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 13) { finalUdf.createScalaUDF(e) @@ -475,8 +475,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: 
Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 14) { finalUdf.createScalaUDF(e) @@ -496,8 +496,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: 
ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 15) { finalUdf.createScalaUDF(e) @@ -517,8 +517,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: ScalaReflection.getClassForCaseClass[A16] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: 
ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: ExpressionEncoder[A16]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 16) { finalUdf.createScalaUDF(e) @@ -538,8 +538,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: ScalaReflection.getClassForCaseClass[A16] :: ScalaReflection.getClassForCaseClass[A17] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: ExpressionEncoder[A16]() :: ExpressionEncoder[A17]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else 
udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 17) { finalUdf.createScalaUDF(e) @@ -559,8 +559,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: ScalaReflection.getClassForCaseClass[A16] :: ScalaReflection.getClassForCaseClass[A17] :: ScalaReflection.getClassForCaseClass[A18] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: ExpressionEncoder[A16]() :: ExpressionEncoder[A17]() :: ExpressionEncoder[A18]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 18) { finalUdf.createScalaUDF(e) @@ -580,8 +580,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: 
TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Nil - val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: ScalaReflection.getClassForCaseClass[A16] :: ScalaReflection.getClassForCaseClass[A17] :: ScalaReflection.getClassForCaseClass[A18] :: ScalaReflection.getClassForCaseClass[A19] :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name) + val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: ExpressionEncoder[A16]() :: ExpressionEncoder[A17]() :: ExpressionEncoder[A18]() :: ExpressionEncoder[A19]() :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 19) { finalUdf.createScalaUDF(e) @@ -601,8 +601,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: 
TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Nil
-    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: ScalaReflection.getClassForCaseClass[A16] :: ScalaReflection.getClassForCaseClass[A17] :: ScalaReflection.getClassForCaseClass[A18] :: ScalaReflection.getClassForCaseClass[A19] :: ScalaReflection.getClassForCaseClass[A20] :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
+    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: ExpressionEncoder[A16]() :: ExpressionEncoder[A17]() :: ExpressionEncoder[A18]() :: ExpressionEncoder[A19]() :: ExpressionEncoder[A20]() :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 20) {
       finalUdf.createScalaUDF(e)
@@ -622,8 +622,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Try(ScalaReflection.schemaFor[A21]).toOption :: Nil
-    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: ScalaReflection.getClassForCaseClass[A16] :: ScalaReflection.getClassForCaseClass[A17] :: ScalaReflection.getClassForCaseClass[A18] :: ScalaReflection.getClassForCaseClass[A19] :: ScalaReflection.getClassForCaseClass[A20] :: ScalaReflection.getClassForCaseClass[A21] :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
+    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: ExpressionEncoder[A16]() :: ExpressionEncoder[A17]() :: ExpressionEncoder[A18]() :: ExpressionEncoder[A19]() :: ExpressionEncoder[A20]() :: ExpressionEncoder[A21]() :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 21) {
       finalUdf.createScalaUDF(e)
@@ -643,8 +643,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Try(ScalaReflection.schemaFor[A21]).toOption :: Try(ScalaReflection.schemaFor[A22]).toOption :: Nil
-    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: ScalaReflection.getClassForCaseClass[A11] :: ScalaReflection.getClassForCaseClass[A12] :: ScalaReflection.getClassForCaseClass[A13] :: ScalaReflection.getClassForCaseClass[A14] :: ScalaReflection.getClassForCaseClass[A15] :: ScalaReflection.getClassForCaseClass[A16] :: ScalaReflection.getClassForCaseClass[A17] :: ScalaReflection.getClassForCaseClass[A18] :: ScalaReflection.getClassForCaseClass[A19] :: ScalaReflection.getClassForCaseClass[A20] :: ScalaReflection.getClassForCaseClass[A21] :: ScalaReflection.getClassForCaseClass[A22] :: Nil
-    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputCaseClasses).withName(name)
+    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: ExpressionEncoder[A16]() :: ExpressionEncoder[A17]() :: ExpressionEncoder[A18]() :: ExpressionEncoder[A19]() :: ExpressionEncoder[A20]() :: ExpressionEncoder[A21]() :: ExpressionEncoder[A22]() :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 22) {
       finalUdf.createScalaUDF(e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 8455b0d593ee5..576dc9d4d8065 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -94,7 +94,7 @@ private[spark] case class SparkUserDefinedFunction(
     f: AnyRef,
     dataType: DataType,
     inputSchemas: Seq[Option[ScalaReflection.Schema]],
-    inputCaseClass: Seq[Option[Class[_]]] = Nil,
+    inputEncoders: Seq[ExpressionEncoder[_]] = Nil,
     name: Option[String] = None,
     nullable: Boolean = true,
     deterministic: Boolean = true) extends UserDefinedFunction {
@@ -116,7 +116,7 @@ private[spark] case class SparkUserDefinedFunction(
       dataType,
       exprs,
       inputsPrimitive,
-      inputCaseClass,
+      inputEncoders,
       inputTypes,
       udfName = name,
       nullable = nullable,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 70fd8cdeb495a..e26f83eb267c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4388,8 +4388,8 @@ object functions {
   def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Nil
-    val inputCaseClasses: Seq[Option[Class[_]]] = Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
+    val inputEncoders: Seq[ExpressionEncoder[_]] = Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4405,8 +4405,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Nil
-    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
+    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4422,8 +4422,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Nil
-    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
+    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4439,8 +4439,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Nil
-    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
+    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4456,8 +4456,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Nil
-    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
+    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4473,8 +4473,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Nil
-    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
+    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4490,8 +4490,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Nil
-    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
+    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4507,8 +4507,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Nil
-    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
+    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4524,8 +4524,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Nil
-    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
+    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4541,8 +4541,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Nil
-    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
+    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders)
     if (nullable) udf else udf.asNonNullable()
   }
 
@@ -4558,8 +4558,8 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A10])).toOption :: Nil
-    val inputCaseClasses: Seq[Option[Class[_]]] = ScalaReflection.getClassForCaseClass[A1] :: ScalaReflection.getClassForCaseClass[A2] :: ScalaReflection.getClassForCaseClass[A3] :: ScalaReflection.getClassForCaseClass[A4] :: ScalaReflection.getClassForCaseClass[A5] :: ScalaReflection.getClassForCaseClass[A6] :: ScalaReflection.getClassForCaseClass[A7] :: ScalaReflection.getClassForCaseClass[A8] :: ScalaReflection.getClassForCaseClass[A9] :: ScalaReflection.getClassForCaseClass[A10] :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputCaseClasses)
+    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: Nil
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders)
     if (nullable) udf else udf.asNonNullable()
   }
 
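With this step, each typed argument of a Scala UDF carries a full ExpressionEncoder instead of just the runtime Class of a case-class argument, so a struct-typed input column can be turned back into the user's case class before the function body runs. A minimal sketch of the user-facing behavior the series is building toward (the Person class, column names, and DataFrame are illustrative, not taken from the patch):

    // assuming a SparkSession in scope as `spark`
    import spark.implicits._
    import org.apache.spark.sql.functions.{struct, udf}

    case class Person(name: String, age: Int)

    // A typed UDF over a case class: the encoder derived for Person is what
    // allows a matching struct<name,age> column to be decoded into a Person.
    val describe = udf((p: Person) => s"${p.name} is ${p.age}")
    val df = Seq(("Alice", 30)).toDF("name", "age")
    df.select(describe(struct($"name", $"age")))

Deriving ExpressionEncoder[T] eagerly for every type parameter has a catch, though: the derivation throws for unsupported types such as Any, and the untyped udf(f: AnyRef, dataType: DataType) variant has no type tags at all. The next patch addresses both cases.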
From 867ad06f8b7068b8cd24970fc45d8fd8dc12428d Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Wed, 18 Mar 2020 23:04:51 +0800
Subject: [PATCH 06/11] fix Any and untyped Scala UDF

---
 .../catalyst/encoders/ExpressionEncoder.scala | 13 ++++++
 .../sql/catalyst/expressions/ScalaUDF.scala   | 16 +++++--
 .../apache/spark/sql/UDFRegistration.scala    | 46 +++++++++----------
 .../sql/expressions/UserDefinedFunction.scala |  2 +-
 .../org/apache/spark/sql/functions.scala     | 22 ++++-----
 .../scala/org/apache/spark/sql/UDFSuite.scala | 11 ++++-
 6 files changed, 68 insertions(+), 42 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index b820cb1a5c522..43d5acf6c455e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -60,6 +60,19 @@ object ExpressionEncoder {
       ClassTag[T](cls))
   }
 
+  /**
+   * Unlike apply(), this method returns None instead of throwing an exception
+   * when no encoder can be found for the type `T`. It is mainly used by the
+   * typed Scala UDF overloads to work around the 'Any' type.
+   */
+  def applyOption[T : TypeTag](): Option[ExpressionEncoder[T]] = {
+    try {
+      Option(ExpressionEncoder[T]())
+    } catch {
+      case _: Exception => None
+    }
+  }
+
   // TODO: improve error message for java bean encoder.
   def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
     val schema = JavaTypeInference.inferDataType(beanClass)._1
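applyOption is the total variant of apply(): it turns a failed encoder derivation into None rather than an exception, so UDF registration no longer aborts when a type parameter has no encoder. Roughly, in REPL style (illustrative):

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    ExpressionEncoder.applyOption[Long]()  // Some(...): an encoder can be derived for Long
    ExpressionEncoder.applyOption[Any]()   // None: ExpressionEncoder[Any]() throws, and
                                           // applyOption converts that failure into None

A Seq[Option[ExpressionEncoder[_]]] keeps one slot per argument, so positions stay aligned even when some arguments (for example those typed Any) have no encoder; those fall back to the plain Catalyst converters, as the ScalaUDF change below shows.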
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index a20e9c6fbf57b..ff80c6f87132b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -53,7 +53,7 @@ case class ScalaUDF(
     dataType: DataType,
     children: Seq[Expression],
     inputPrimitives: Seq[Boolean],
-    inputEncoders: Seq[ExpressionEncoder[_]] = Nil,
+    inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil,
     inputTypes: Seq[AbstractDataType] = Nil,
     udfName: Option[String] = None,
     nullable: Boolean = true,
@@ -65,10 +65,17 @@ case class ScalaUDF(
   override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})"
 
   private def createToScalaConverter(i: Int, dataType: DataType): Any => Any = {
-    val encoder = inputEncoders(i)
-    encoder.isSerializedAsStructForTopLevel match {
-      case true => r: Any => encoder.resolveAndBind().fromRow(r.asInstanceOf[InternalRow])
-      case false => CatalystTypeConverters.createToScalaConverter(dataType)
+    inputEncoders.length match {
+      case 0 =>
+        // for untyped Scala UDF
+        CatalystTypeConverters.createToScalaConverter(dataType)
+      case _ =>
+        val encoder = inputEncoders(i)
+        if (encoder.isDefined && encoder.get.isSerializedAsStructForTopLevel) {
+          r: Any => encoder.get.resolveAndBind().fromRow(r.asInstanceOf[InternalRow])
+        } else {
+          CatalystTypeConverters.createToScalaConverter(dataType)
+        }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index ebbee7296e008..03b73bf5e04aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -181,7 +181,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 0) {
@@ -202,7 +202,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 1) {
@@ -223,7 +223,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 2) {
@@ -244,7 +244,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 3) {
@@ -265,7 +265,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 4) {
@@ -286,7 +286,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 5) {
@@ -307,7 +307,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 6) {
@@ -328,7 +328,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 7) {
@@ -349,7 +349,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 8) {
@@ -370,7 +370,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 9) {
@@ -391,7 +391,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 10) {
@@ -412,7 +412,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 11) {
@@ -433,7 +433,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 12) {
@@ -454,7 +454,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 13) {
@@ -475,7 +475,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 14) {
@@ -496,7 +496,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 15) {
@@ -517,7 +517,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: ExpressionEncoder[A16]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: ExpressionEncoder.applyOption[A16]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 16) {
@@ -538,7 +538,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: ExpressionEncoder[A16]() :: ExpressionEncoder[A17]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: ExpressionEncoder.applyOption[A16]() :: ExpressionEncoder.applyOption[A17]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 17) {
@@ -559,7 +559,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: ExpressionEncoder[A16]() :: ExpressionEncoder[A17]() :: ExpressionEncoder[A18]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: ExpressionEncoder.applyOption[A16]() :: ExpressionEncoder.applyOption[A17]() :: ExpressionEncoder.applyOption[A18]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 18) {
@@ -580,7 +580,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: ExpressionEncoder[A16]() :: ExpressionEncoder[A17]() :: ExpressionEncoder[A18]() :: ExpressionEncoder[A19]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: ExpressionEncoder.applyOption[A16]() :: ExpressionEncoder.applyOption[A17]() :: ExpressionEncoder.applyOption[A18]() :: ExpressionEncoder.applyOption[A19]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 19) {
@@ -601,7 +601,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: ExpressionEncoder[A16]() :: ExpressionEncoder[A17]() :: ExpressionEncoder[A18]() :: ExpressionEncoder[A19]() :: ExpressionEncoder[A20]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: ExpressionEncoder.applyOption[A16]() :: ExpressionEncoder.applyOption[A17]() :: ExpressionEncoder.applyOption[A18]() :: ExpressionEncoder.applyOption[A19]() :: ExpressionEncoder.applyOption[A20]() :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 20) {
@@ -622,7 +622,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Try(ScalaReflection.schemaFor[A21]).toOption :: Nil
-    val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: ExpressionEncoder[A16]() :: ExpressionEncoder[A17]() :: ExpressionEncoder[A18]() :: ExpressionEncoder[A19]() :: ExpressionEncoder[A20]() :: ExpressionEncoder[A21]() :: Nil
+    val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() ::
ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: ExpressionEncoder.applyOption[A16]() :: ExpressionEncoder.applyOption[A17]() :: ExpressionEncoder.applyOption[A18]() :: ExpressionEncoder.applyOption[A19]() :: ExpressionEncoder.applyOption[A20]() :: ExpressionEncoder.applyOption[A21]() :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 21) { @@ -643,7 +643,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Try(ScalaReflection.schemaFor[A21]).toOption :: Try(ScalaReflection.schemaFor[A22]).toOption :: Nil - val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: ExpressionEncoder[A11]() :: ExpressionEncoder[A12]() :: ExpressionEncoder[A13]() :: ExpressionEncoder[A14]() :: ExpressionEncoder[A15]() :: ExpressionEncoder[A16]() :: ExpressionEncoder[A17]() :: ExpressionEncoder[A18]() :: ExpressionEncoder[A19]() :: ExpressionEncoder[A20]() :: ExpressionEncoder[A21]() :: ExpressionEncoder[A22]() :: Nil + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: 
ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: ExpressionEncoder.applyOption[A16]() :: ExpressionEncoder.applyOption[A17]() :: ExpressionEncoder.applyOption[A18]() :: ExpressionEncoder.applyOption[A19]() :: ExpressionEncoder.applyOption[A20]() :: ExpressionEncoder.applyOption[A21]() :: ExpressionEncoder.applyOption[A22]() :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 22) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 576dc9d4d8065..52d96088fd943 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -94,7 +94,7 @@ private[spark] case class SparkUserDefinedFunction( f: AnyRef, dataType: DataType, inputSchemas: Seq[Option[ScalaReflection.Schema]], - inputEncoders: Seq[ExpressionEncoder[_]] = Nil, + inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil, name: Option[String] = None, nullable: Boolean = true, deterministic: Boolean = true) extends UserDefinedFunction { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e26f83eb267c8..9ef7407fb6dc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4388,7 +4388,7 @@ object functions { def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Nil - val inputEncoders: Seq[ExpressionEncoder[_]] = Nil + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4405,7 +4405,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Nil - val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: Nil + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4422,7 +4422,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Nil - val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: Nil + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) if (nullable) udf else 
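// A minimal usage sketch of what the Option-typed inputEncoders enable for functions.udf
// (the Point case class and column name are hypothetical, not part of this patch):
//
//   case class Point(x: Int, y: Int)
//   val dist = udf((p: Point) => math.sqrt(p.x.toDouble * p.x + p.y.toDouble * p.y))
//   df.select(dist(col("point")))
//
// Each argument position carries Some(encoder) when an ExpressionEncoder can be derived
// for its type; a position whose derivation yields None keeps the existing
// CatalystTypeConverters-based conversion.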
udf.asNonNullable() } @@ -4439,7 +4439,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Nil - val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: Nil + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4456,7 +4456,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Nil - val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: Nil + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4473,7 +4473,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Nil - val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: Nil + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4490,7 +4490,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: 
Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Nil - val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: Nil + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4507,7 +4507,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Nil - val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: Nil + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4524,7 +4524,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Nil - val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: Nil + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: 
ExpressionEncoder.applyOption[A8]() :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4541,7 +4541,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Nil - val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: Nil + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4558,7 +4558,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A10])).toOption :: Nil - val inputEncoders: Seq[ExpressionEncoder[_]] = ExpressionEncoder[A1]() :: ExpressionEncoder[A2]() :: ExpressionEncoder[A3]() :: ExpressionEncoder[A4]() :: ExpressionEncoder[A5]() :: ExpressionEncoder[A6]() :: ExpressionEncoder[A7]() :: ExpressionEncoder[A8]() :: ExpressionEncoder[A9]() :: ExpressionEncoder[A10]() :: Nil + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: 
ExpressionEncoder.applyOption[A10]() :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) if (nullable) udf else udf.asNonNullable() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 5f33614f50436..99352f04e10c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -572,10 +572,17 @@ class UDFSuite extends QueryTest with SharedSparkSession { checkAnswer(df.select(myUdf(Column("col1"), Column("col2"))), Row(500) :: Nil) } - test("input case class parameter and return case class ") { - val f = (d1: TestData) => TestData(d1.key * 2, "copy") + test("input case class parameter and return case class") { + val f = (d: TestData) => TestData(d.key * 2, "copy") val myUdf = udf(f) val df = Seq(("data", TestData(50, "d2"))).toDF("col1", "col2") checkAnswer(df.select(myUdf(Column("col2"))), Row(Row(100, "copy")) :: Nil) } + + test("any and case class") { + val f = (any: Any, d: TestData) => s"${any.toString}, ${d.value}" + val myUdf = udf(f) + val df = Seq(("Hello", TestData(50, "World"))).toDF("col1", "col2") + checkAnswer(df.select(myUdf(Column("col1"), Column("col2"))), Row("Hello, World") :: Nil) + } } From 23ca0988637571a4b1e210acbbee7c34c927d56b Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Thu, 19 Mar 2020 10:56:39 +0800 Subject: [PATCH 07/11] address comment --- .../sql/catalyst/expressions/ScalaUDF.scala | 23 ++-- .../apache/spark/sql/UDFRegistration.scala | 119 +++++++----------- .../sql/expressions/UserDefinedFunction.scala | 24 ++-- .../org/apache/spark/sql/functions.scala | 87 ++++++------- .../spark/sql/IntegratedUDFTestUtils.scala | 2 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 10 +- 6 files changed, 127 insertions(+), 138 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index ff80c6f87132b..87ffd74445663 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable + import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -62,19 +64,22 @@ case class ScalaUDF( override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) + private lazy val resolvedEnc = mutable.HashMap[Int, ExpressionEncoder[_]]() + override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})" private def createToScalaConverter(i: Int, dataType: DataType): Any => Any = { - inputEncoders.length match { - case 0 => - // for untyped Scala UDF + if (inputEncoders.isEmpty) { + // for untyped Scala UDF + CatalystTypeConverters.createToScalaConverter(dataType) + } else { + val encoder = inputEncoders(i) + if (encoder.isDefined && encoder.get.isSerializedAsStructForTopLevel) { + val enc = resolvedEnc.getOrElseUpdate(i, encoder.get.resolveAndBind()) + row: Any => enc.fromRow(row.asInstanceOf[InternalRow]) + } else { CatalystTypeConverters.createToScalaConverter(dataType) - case _ => - val encoder = inputEncoders(i) - encoder.isDefined && 
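// A minimal sketch of the decoding the new branch above performs for a struct-encoded
// argument (identifiers follow the patch; the standalone form is illustrative):
//
//   val enc = resolvedEnc.getOrElseUpdate(i, encoder.get.resolveAndBind())
//   val toScala: Any => Any = row => enc.fromRow(row.asInstanceOf[InternalRow])
//
// resolveAndBind() now runs once per argument position and is memoized in resolvedEnc,
// whereas the removed branch below re-resolved the encoder on every row.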
encoder.get.isSerializedAsStructForTopLevel match { - case true => r: Any => encoder.get.resolveAndBind().fromRow(r.asInstanceOf[InternalRow]) - case false => CatalystTypeConverters.createToScalaConverter(dataType) - } + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 03b73bf5e04aa..c411a64f6229c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -125,7 +125,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends (0 to 22).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val inputSchemas = (1 to x).foldRight("Nil")((i, s) => {s"Try(ScalaReflection.schemaFor[A$i]).toOption :: $s"}) + val inputEncoders = (1 to x).foldRight("Nil")((i, s) => {s"Try(ExpressionEncoder[A$i]()).toOption :: $s"}) println(s""" |/** | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). @@ -134,8 +134,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | */ |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - | val inputSchemas: Seq[Option[ScalaReflection.Schema]] = $inputSchemas - | val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + | val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = $inputEncoders + | val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) | val finalUdf = if (nullable) udf else udf.asNonNullable() | def builder(e: Seq[Expression]) = if (e.length == $x) { | finalUdf.createScalaUDF(e) @@ -180,9 +180,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Nil val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 0) { finalUdf.createScalaUDF(e) @@ -201,9 +200,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 1) { finalUdf.createScalaUDF(e) @@ -222,9 +220,8 
@@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 2) { finalUdf.createScalaUDF(e) @@ -243,9 +240,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 3) { finalUdf.createScalaUDF(e) @@ -264,9 +260,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if 
(e.length == 4) { finalUdf.createScalaUDF(e) @@ -285,9 +280,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 5) { finalUdf.createScalaUDF(e) @@ -306,9 +300,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 6) { finalUdf.createScalaUDF(e) @@ -327,9 +320,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = 
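// Why each argument uses Try(ExpressionEncoder[Ai]()).toOption (a reading aid, not patch
// content): encoder derivation fails for types like Any, so a failed derivation becomes
// None and ScalaUDF falls back to CatalystTypeConverters for that position, e.g.:
//
//   val inputEncoders: Seq[Option[ExpressionEncoder[_]]] =
//     Try(ExpressionEncoder[Any]()).toOption ::      // None: no encoder derivable for Any
//     Try(ExpressionEncoder[TestData]()).toOption :: // Some(_): case class encoder
//     Nil
//
// This is the combination exercised by the "any and case class" test added to UDFSuite
// earlier in this series.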
ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 7) { finalUdf.createScalaUDF(e) @@ -348,9 +340,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 8) { finalUdf.createScalaUDF(e) @@ -369,9 +360,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, 
A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 9) { finalUdf.createScalaUDF(e) @@ -390,9 +380,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: 
Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 10) { finalUdf.createScalaUDF(e) @@ -411,9 +400,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 11) { finalUdf.createScalaUDF(e) @@ -432,9 +420,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption 
:: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 12) { finalUdf.createScalaUDF(e) @@ -453,9 +440,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: 
ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 13) { finalUdf.createScalaUDF(e) @@ -474,9 +460,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: 
Try(ExpressionEncoder[A14]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 14) { finalUdf.createScalaUDF(e) @@ -495,9 +480,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 15) { finalUdf.createScalaUDF(e) @@ -516,9 +500,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: 
TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: ExpressionEncoder.applyOption[A16]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 16) { finalUdf.createScalaUDF(e) @@ -537,9 +520,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: 
Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: ExpressionEncoder.applyOption[A16]() :: ExpressionEncoder.applyOption[A17]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 17) { finalUdf.createScalaUDF(e) @@ -558,9 +540,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: 
Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: ExpressionEncoder.applyOption[A16]() :: ExpressionEncoder.applyOption[A17]() :: ExpressionEncoder.applyOption[A18]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 18) { finalUdf.createScalaUDF(e) @@ -579,9 +560,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: 
Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: ExpressionEncoder.applyOption[A16]() :: ExpressionEncoder.applyOption[A17]() :: ExpressionEncoder.applyOption[A18]() :: ExpressionEncoder.applyOption[A19]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 19) { finalUdf.createScalaUDF(e) @@ -600,9 +580,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: 
Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: ExpressionEncoder.applyOption[A16]() :: ExpressionEncoder.applyOption[A17]() :: ExpressionEncoder.applyOption[A18]() :: ExpressionEncoder.applyOption[A19]() :: ExpressionEncoder.applyOption[A20]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 20) { finalUdf.createScalaUDF(e) @@ -621,9 +600,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: 
Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Try(ScalaReflection.schemaFor[A21]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: ExpressionEncoder.applyOption[A16]() :: ExpressionEncoder.applyOption[A17]() :: ExpressionEncoder.applyOption[A18]() :: ExpressionEncoder.applyOption[A19]() :: ExpressionEncoder.applyOption[A20]() :: ExpressionEncoder.applyOption[A21]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Try(ExpressionEncoder[A21]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 21) { finalUdf.createScalaUDF(e) @@ -642,9 +620,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { 
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Try(ScalaReflection.schemaFor[A21]).toOption :: Try(ScalaReflection.schemaFor[A22]).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: ExpressionEncoder.applyOption[A11]() :: ExpressionEncoder.applyOption[A12]() :: ExpressionEncoder.applyOption[A13]() :: ExpressionEncoder.applyOption[A14]() :: ExpressionEncoder.applyOption[A15]() :: ExpressionEncoder.applyOption[A16]() :: ExpressionEncoder.applyOption[A17]() :: ExpressionEncoder.applyOption[A18]() :: ExpressionEncoder.applyOption[A19]() :: ExpressionEncoder.applyOption[A20]() :: ExpressionEncoder.applyOption[A21]() :: ExpressionEncoder.applyOption[A22]() :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputSchemas, inputEncoders).withName(name) + val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Try(ExpressionEncoder[A21]()).toOption :: Try(ExpressionEncoder[A22]()).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 22) { finalUdf.createScalaUDF(e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 52d96088fd943..e9a289483e6b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -93,7 +93,6 @@ sealed abstract class UserDefinedFunction { private[spark] case class SparkUserDefinedFunction( f: AnyRef, dataType: DataType, - inputSchemas: Seq[Option[ScalaReflection.Schema]], inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil, name: Option[String] = None, nullable: Boolean = true, @@ -105,12 +104,23 @@ private[spark] case class SparkUserDefinedFunction( } private[sql] def createScalaUDF(exprs: Seq[Expression]): ScalaUDF = { - // It's possible that some of the inputs don't have a specific type(e.g. `Any`), skip type - // check. - val inputTypes = inputSchemas.map(_.map(_.dataType).getOrElse(AnyDataType)) - // `ScalaReflection.Schema.nullable` is false iff the type is primitive. Also `Any` is not - // primitive. - val inputsPrimitive = inputSchemas.map(_.map(!_.nullable).getOrElse(false)) + // It's possible that some of the inputs don't have a specific encoder (e.g. `Any`), so we + // skip the type check for them. `nullable` is false iff the type is primitive; `Any` is not primitive. + val (inputTypes, inputsPrimitive) = inputEncoders.map { encoderOpt => + if (encoderOpt.isDefined) { + val encoder = encoderOpt.get + if (encoder.isSerializedAsStruct) { + // struct type is not primitive + (encoder.schema, false) + } else { + val field = encoder.schema.head + (field.dataType, !field.nullable) + } + } else { + (AnyDataType, false) + } + }.unzip + ScalaUDF( f, dataType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9ef7407fb6dc7..7aa9a1a725407 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4268,7 +4268,7 @@ object functions { (0 to 10).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val inputSchemas = (1 to x).foldRight("Nil")((i, s) => {s"Try(ScalaReflection.schemaFor(typeTag[A$i])).toOption :: $s"}) + val inputEncoders = (1 to x).foldRight("Nil")((i, s) => {s"Try(ExpressionEncoder[A$i]()).toOption :: $s"}) println(s""" |/** | * Defines a Scala closure of $x arguments as user-defined function (UDF).
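For reference, the encoder-driven derivation in `createScalaUDF` above can be read as one small mapping per argument. The sketch below is illustrative only: the `typeAndPrimitive` helper is hypothetical and not part of this patch, and it assumes it compiles inside Spark's `org.apache.spark.sql` internals (like the patch itself), since `AnyDataType` and `ExpressionEncoder` are not public API.

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType}

// Mirrors the unzip above: each optional encoder yields the expected input type
// plus a flag telling the analyzer whether the input is a Scala primitive
// (and therefore needs the null-in/null-out handling added for UDFs).
def typeAndPrimitive(encoderOpt: Option[ExpressionEncoder[_]]): (AbstractDataType, Boolean) =
  encoderOpt match {
    case Some(enc) if enc.isSerializedAsStruct =>
      (enc.schema, false) // case classes are serialized as structs; never primitive
    case Some(enc) =>
      val field = enc.schema.head
      (field.dataType, !field.nullable) // a non-nullable single field marks a primitive input
    case None =>
      (AnyDataType, false) // no encoder (e.g. `Any`): skip type coercion entirely
  }

// e.g. typeAndPrimitive(Some(ExpressionEncoder[Int]())) gives (IntegerType, true),
// while typeAndPrimitive(None) gives (AnyDataType, false).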
@@ -4281,8 +4281,8 @@ object functions { | */ |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - | val inputSchemas = $inputSchemas - | val udf = SparkUserDefinedFunction(f, dataType, inputSchemas) + | val inputEncoders = $inputEncoders + | val udf = SparkUserDefinedFunction(f, dataType, inputEncoders) | if (nullable) udf else udf.asNonNullable() |}""".stripMargin) } @@ -4305,7 +4305,7 @@ object functions { | */ |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { | val func = $funcCall - | SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill($i)(None)) + | SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill($i)(None)) |}""".stripMargin) } @@ -4387,9 +4387,8 @@ object functions { */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil - val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) + val inputEncoders = Nil + val udf = SparkUserDefinedFunction(f, dataType, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4404,9 +4403,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) + val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Nil + val udf = SparkUserDefinedFunction(f, dataType, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4421,9 +4419,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) + val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Nil + val udf = SparkUserDefinedFunction(f, dataType, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4438,9 +4435,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) + val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: 
Try(ExpressionEncoder[A3]()).toOption :: Nil + val udf = SparkUserDefinedFunction(f, dataType, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4455,9 +4451,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) + val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Nil + val udf = SparkUserDefinedFunction(f, dataType, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4472,9 +4467,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) + val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Nil + val udf = SparkUserDefinedFunction(f, dataType, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4489,9 +4483,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) + val inputEncoders = 
Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Nil + val udf = SparkUserDefinedFunction(f, dataType, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4506,9 +4499,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) + val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Nil + val udf = SparkUserDefinedFunction(f, dataType, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4523,9 +4515,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) + val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Nil + val udf = 
SparkUserDefinedFunction(f, dataType, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4540,9 +4531,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) + val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Nil + val udf = SparkUserDefinedFunction(f, dataType, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4557,9 +4547,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A10])).toOption :: Nil - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = ExpressionEncoder.applyOption[A1]() :: ExpressionEncoder.applyOption[A2]() :: ExpressionEncoder.applyOption[A3]() :: ExpressionEncoder.applyOption[A4]() :: ExpressionEncoder.applyOption[A5]() :: ExpressionEncoder.applyOption[A6]() :: ExpressionEncoder.applyOption[A7]() :: ExpressionEncoder.applyOption[A8]() :: ExpressionEncoder.applyOption[A9]() :: ExpressionEncoder.applyOption[A10]() :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputSchemas, inputEncoders) + val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: 
Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Nil + val udf = SparkUserDefinedFunction(f, dataType, inputEncoders) if (nullable) udf else udf.asNonNullable() } @@ -4578,7 +4567,7 @@ object functions { */ def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { val func = () => f.asInstanceOf[UDF0[Any]].call() - SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(0)(None)) + SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(0)(None)) } /** @@ -4592,7 +4581,7 @@ object functions { */ def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) - SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(1)(None)) + SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(1)(None)) } /** @@ -4606,7 +4595,7 @@ object functions { */ def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(2)(None)) + SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(2)(None)) } /** @@ -4620,7 +4609,7 @@ object functions { */ def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(3)(None)) + SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(3)(None)) } /** @@ -4634,7 +4623,7 @@ object functions { */ def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(4)(None)) + SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(4)(None)) } /** @@ -4648,7 +4637,7 @@ object functions { */ def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(5)(None)) + SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(5)(None)) } /** @@ -4662,7 +4651,7 @@ object functions { */ def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(6)(None)) + SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(6)(None)) } /** @@ -4676,7 +4665,7 @@ object functions { */ def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(7)(None)) + SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(7)(None)) } /** @@ -4690,7 +4679,7 @@ object functions { */ def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): 
UserDefinedFunction = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(8)(None)) + SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(8)(None)) } /** @@ -4704,7 +4693,7 @@ object functions { */ def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(9)(None)) + SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(9)(None)) } /** @@ -4718,7 +4707,7 @@ object functions { */ def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(10)(None)) + SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(10)(None)) } // scalastyle:on parameter.number @@ -4756,7 +4745,7 @@ object functions { s"caution." throw new AnalysisException(errorMsg) } - SparkUserDefinedFunction(f, dataType, inputSchemas = Nil) + SparkUserDefinedFunction(f, dataType, inputEncoders = Nil) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 51150a1b38b49..4a4504a075060 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -337,7 +337,7 @@ object IntegratedUDFTestUtils extends SQLHelper { input.toString }, StringType, - inputSchemas = Seq.fill(1)(None), + inputEncoders = Seq.fill(1)(None), name = Some(name)) { override def apply(exprs: Column*): Column = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 99352f04e10c2..08f41f6819a0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -579,10 +579,18 @@ class UDFSuite extends QueryTest with SharedSparkSession { checkAnswer(df.select(myUdf(Column("col2"))), Row(Row(100, "copy")) :: Nil) } - test("any and case class") { + test("any and case class parameter") { val f = (any: Any, d: TestData) => s"${any.toString}, ${d.value}" val myUdf = udf(f) val df = Seq(("Hello", TestData(50, "World"))).toDF("col1", "col2") checkAnswer(df.select(myUdf(Column("col1"), Column("col2"))), Row("Hello, World") :: Nil) } + + test("nested case class parameter") { + val f = (y: Int, training: TrainingSales) => training.sales.year + y + val myUdf = udf(f) + val df = Seq((20, TrainingSales("training", CourseSales("course", 2000, 3.14)))) + .toDF("col1", "col2") + checkAnswer(df.select(myUdf(Column("col1"), Column("col2"))), Row(2020) :: Nil) + } } From 842d6fa7453d0cd34a41ebf2eb13c93c899ad83d Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Thu, 19 Mar 2020 22:58:33 +0800 Subject: [PATCH 08/11] use encoder only --- .../sql/catalyst/analysis/Analyzer.scala | 8 +-- .../sql/catalyst/expressions/ScalaUDF.scala | 58 +++++++++++++++---- 
.../sql/catalyst/analysis/AnalysisSuite.scala | 18 ++++-- .../catalyst/expressions/ScalaUDFSuite.scala | 18 ++++-- .../optimizer/EliminateSortsSuite.scala | 4 +- .../sql/catalyst/trees/TreeNodeSuite.scala | 4 +- .../apache/spark/sql/UDFRegistration.scala | 48 +++++++-------- .../datasources/FileFormatDataWriter.scala | 3 +- .../sql/expressions/UserDefinedFunction.scala | 19 ------ 9 files changed, 106 insertions(+), 74 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index fb81923291f7f..95e86512bb04c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2707,13 +2707,13 @@ class Analyzer( case p => p transformExpressionsUp { - case udf @ ScalaUDF(_, _, inputs, inputPrimitives, _, _, _, _, _) - if inputPrimitives.contains(true) => + case udf @ ScalaUDF(_, _, inputs, _, _, _, _) + if udf.inputPrimitives.contains(true) => // Otherwise, add special handling of null for fields that can't accept null. // The result of operations like this, when passed null, is generally to return null. - assert(inputPrimitives.length == inputs.length) + assert(udf.inputPrimitives.length == inputs.length) - val inputPrimitivesPair = inputPrimitives.zip(inputs) + val inputPrimitivesPair = udf.inputPrimitives.zip(inputs) val inputNullCheck = inputPrimitivesPair.collect { case (isPrimitive, input) if isPrimitive && input.nullable => IsNull(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 87ffd74445663..59cecd519ca66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.types.{AbstractDataType, DataType} +import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType} /** * User-defined function. @@ -34,17 +34,9 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType} * null. Use boxed type or [[Option]] if you want to do the null-handling yourself. * @param dataType Return type of function. * @param children The input expressions of this UDF. - * @param inputPrimitives The analyzer should be aware of Scala primitive types so as to make the - * UDF return null if there is any null input value of these types. On the - * other hand, Java UDFs can only have boxed types, thus this parameter will - * always be all false. * @param inputEncoders ExpressionEncoder for each input parameter. An input parameter that is * serialized as a struct uses its encoder, instead of CatalystTypeConverters, * to convert the internal value to a Scala value. - * @param inputTypes The expected input types of this UDF, used to perform type coercion. If we do * not want to perform coercion, simply use "Nil". Note that it would've been * better to use Option of Seq[DataType] so we can use "None" as the case for no * type coercion.
However, that would require more refactoring of the codebase. * @param udfName The user-specified name of this UDF. * @param nullable True if the UDF can return null value. * @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result @@ -54,9 +46,7 @@ case class ScalaUDF( function: AnyRef, dataType: DataType, children: Seq[Expression], - inputPrimitives: Seq[Boolean], inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil, - inputTypes: Seq[AbstractDataType] = Nil, udfName: Option[String] = None, nullable: Boolean = true, udfDeterministic: Boolean = true) @@ -68,6 +58,52 @@ case class ScalaUDF( override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})" + /** + * The analyzer should be aware of Scala primitive types so as to make the + * UDF return null if there is any null input value of these types. On the + * other hand, Java UDFs can only have boxed types, thus the returned flags + * will always be all false for Java UDFs. + */ + def inputPrimitives: Seq[Boolean] = { + inputEncoders.map { encoderOpt => + // It's possible that some of the inputs don't have a specific encoder (e.g. `Any`) + if (encoderOpt.isDefined) { + val encoder = encoderOpt.get + if (encoder.isSerializedAsStruct) { + // struct type is not primitive + false + } else { + // `nullable` is false iff the type is primitive + !encoder.schema.head.nullable + } + } else { + // `Any` is not primitive + false + } + } + } + + /** + * The expected input types of this UDF, used to perform type coercion. If we do + * not want to perform coercion, simply use "Nil". Note that it would've been + * better to use Option of Seq[DataType] so we can use "None" as the case for no + * type coercion. However, that would require more refactoring of the codebase.
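+ * For example, an `Int` input maps to `IntegerType`, a case class input maps to + * its encoder's struct schema, and an input with no encoder (e.g. `Any`) maps to + * `AnyDataType`, which effectively disables coercion for that argument.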
+ */ + def inputTypes: Seq[AbstractDataType] = { + inputEncoders.map { encoderOpt => + if (encoderOpt.isDefined) { + val encoder = encoderOpt.get + if (encoder.isSerializedAsStruct) { + encoder.schema + } else { + encoder.schema.head.dataType + } + } else { + AnyDataType + } + } + } + private def createToScalaConverter(i: Int, dataType: DataType): Any => Any = { if (inputEncoders.isEmpty) { // for untyped Scala UDF diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 8451b9b50eff3..02472e153b09e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum} @@ -326,20 +327,21 @@ class AnalysisSuite extends AnalysisTest with Matchers { } // non-primitive parameters do not need special null handling - val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil, false :: Nil) + val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil, + Option(ExpressionEncoder[String]()) :: Nil) val expected1 = udf1 checkUDF(udf1, expected1) // only primitive parameter needs special null handling val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil, - false :: true :: Nil) + Option(ExpressionEncoder[String]()) :: Option(ExpressionEncoder[Double]()) :: Nil) val expected2 = If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil)) checkUDF(udf2, expected2) // special null handling should apply to all primitive parameters val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil, - true :: true :: Nil) + Option(ExpressionEncoder[Short]()) :: Option(ExpressionEncoder[Double]()) :: Nil) val expected3 = If( IsNull(short) || IsNull(double), nullResult, @@ -351,7 +353,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { (s: Short, d: Double) => "x", StringType, short :: nonNullableDouble :: Nil, - true :: true :: Nil) + Option(ExpressionEncoder[Short]()) :: Option(ExpressionEncoder[Double]()) :: Nil) val expected4 = If( IsNull(short), nullResult, @@ -362,8 +364,12 @@ class AnalysisSuite extends AnalysisTest with Matchers { test("SPARK-24891 Fix HandleNullInputsForUDF rule") { val a = testRelation.output(0) val func = (x: Int, y: Int) => x + y - val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil, false :: false :: Nil) - val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil, false :: false :: Nil) + val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil, + Option(ExpressionEncoder[java.lang.Integer]()) :: + Option(ExpressionEncoder[java.lang.Integer]()) :: Nil) + val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil, + Option(ExpressionEncoder[java.lang.Integer]()) :: + Option(ExpressionEncoder[java.lang.Integer]()) :: Nil) val plan = Project(Alias(udf2, "")() 
:: Nil, testRelation) comparePlans(plan.analyze, plan.analyze.analyze) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index c5ffc381b58e2..836b2eaa642a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType} @@ -27,10 +28,12 @@ import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType} class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { test("basic") { - val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, true :: Nil) + val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, + Option(ExpressionEncoder[Int]()) :: Nil) checkEvaluation(intUdf, 2) - val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil) + val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, + Option(ExpressionEncoder[String]()) :: Nil) checkEvaluation(stringUdf, "ax") } @@ -39,7 +42,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { (s: String) => s.toLowerCase(Locale.ROOT), StringType, Literal.create(null, StringType) :: Nil, - false :: Nil) + Option(ExpressionEncoder[String]()) :: Nil) val e1 = intercept[SparkException](udf.eval()) assert(e1.getMessage.contains("Failed to execute user defined function")) @@ -52,7 +55,8 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22695: ScalaUDF should not use global variables") { val ctx = new CodegenContext - ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil).genCode(ctx) + ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, + Option(ExpressionEncoder[String]()) :: Nil).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } @@ -61,7 +65,8 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { val udf = ScalaUDF( (a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)), DecimalType.SYSTEM_DEFAULT, - Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil) + Literal(BigDecimal("12345678901234567890.123")) :: Nil, + Option(ExpressionEncoder[java.math.BigDecimal]()) :: Nil) val e1 = intercept[ArithmeticException](udf.eval()) assert(e1.getMessage.contains("cannot be represented as Decimal")) val e2 = intercept[SparkException] { @@ -73,7 +78,8 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { val udf = ScalaUDF( (a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)), DecimalType.SYSTEM_DEFAULT, - Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil) + Literal(BigDecimal("12345678901234567890.123")) :: Nil, + Option(ExpressionEncoder[java.math.BigDecimal]()) :: Nil) checkEvaluation(udf, null) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index d9a6fbf81de91..d7eb048ba8705 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -244,7 +245,8 @@ class EliminateSortsSuite extends PlanTest { } test("should not remove orderBy in groupBy clause with ScalaUDF as aggs") { - val scalaUdf = ScalaUDF((s: Int) => s, IntegerType, 'a :: Nil, true :: Nil) + val scalaUdf = ScalaUDF((s: Int) => s, IntegerType, 'a :: Nil, + Option(ExpressionEncoder[Int]()) :: Nil) val projectPlan = testRelation.select('a, 'b) val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc) val groupByPlan = orderByPlan.groupBy('a)(scalaUdf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index e72b2e9b1b214..f5259706325eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{AliasIdentifier, FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.dsl.expressions.DslString +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin, SQLHelper} @@ -594,7 +595,8 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { } test("toJSON should not throws java.lang.StackOverflowError") { - val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr), false :: Nil) + val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr), + Option(ExpressionEncoder[String]()) :: Nil) // Should not throw java.lang.StackOverflowError udf.toJSON } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index c411a64f6229c..ced4af46c3f30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -163,7 +163,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { | val func = $funcCall | def builder(e: Seq[Expression]) = if (e.length == $i) { - | ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + | ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) | } else { | throw new AnalysisException("Invalid number of arguments for function " + name + | ". 
Expected: $i; Found: " + e.length) @@ -731,7 +731,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF0[_], returnType: DataType): Unit = { val func = () => f.asInstanceOf[UDF0[Any]].call() def builder(e: Seq[Expression]) = if (e.length == 0) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 0; Found: " + e.length) @@ -746,7 +746,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) def builder(e: Seq[Expression]) = if (e.length == 1) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 1; Found: " + e.length) @@ -761,7 +761,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 2) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 2; Found: " + e.length) @@ -776,7 +776,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 3) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 3; Found: " + e.length) @@ -791,7 +791,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 4) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 4; Found: " + e.length) @@ -806,7 +806,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 5) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 5; Found: " + e.length) @@ -821,7 +821,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 6) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 6; Found: " + e.length) @@ -836,7 +836,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 7) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 7; Found: " + e.length) @@ -851,7 +851,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 8) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 8; Found: " + e.length) @@ -866,7 +866,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 9) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 9; Found: " + e.length) @@ -881,7 +881,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 10) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 10; Found: " + e.length) @@ -896,7 +896,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 11) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 11; Found: " + e.length) @@ -911,7 +911,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 12) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 12; Found: " + e.length) @@ -926,7 +926,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 13) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 13; Found: " + e.length) @@ -941,7 +941,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 14) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 14; Found: " + e.length) @@ -956,7 +956,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 15) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 15; Found: " + e.length) @@ -971,7 +971,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 16) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 16; Found: " + e.length) @@ -986,7 +986,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 17) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 17; Found: " + e.length) @@ -1001,7 +1001,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 18) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 18; Found: " + e.length) @@ -1016,7 +1016,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 19) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 19; Found: " + e.length) @@ -1031,7 +1031,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 20) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 20; Found: " + e.length) @@ -1046,7 +1046,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 21) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 21; Found: " + e.length) @@ -1061,7 +1061,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 22) { - ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, Nil, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 22; Found: " + e.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index 50c4f6cd57a96..edb49d3f90ca3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -182,8 +182,7 @@ class DynamicPartitionDataWriter( val partitionName = ScalaUDF( ExternalCatalogUtils.getPartitionPathString _, StringType, - Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId))), - Seq(false, false)) + Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId)))) if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index e9a289483e6b1..2ef6e3d291cef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -104,30 +104,11 @@ private[spark] case class SparkUserDefinedFunction( } private[sql] def createScalaUDF(exprs: Seq[Expression]): ScalaUDF = { - // It's possible that some of the inputs don't have a specific encoder(e.g. `Any`). - // And `nullable` is false iff the type is primitive. Also `Any` is not primitive. - val (inputTypes, inputsPrimitive) = inputEncoders.map { encoderOpt => - if (encoderOpt.isDefined) { - val encoder = encoderOpt.get - if (encoder.isSerializedAsStruct) { - // struct type is not primitive - (encoder.schema, false) - } else { - val field = encoder.schema.head - (field.dataType, !field.nullable) - } - } else { - (AnyDataType, false) - } - }.unzip - ScalaUDF( f, dataType, exprs, - inputsPrimitive, inputEncoders, - inputTypes, udfName = name, nullable = nullable, udfDeterministic = deterministic) From 174e01753be337ce1b03a28890681fceb47c875c Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 23 Mar 2020 09:29:08 +0800 Subject: [PATCH 09/11] no need resolvedEnc --- .../org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 59cecd519ca66..3259f53072f01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -54,8 +54,6 @@ case class ScalaUDF( override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) - private lazy val resolvedEnc = mutable.HashMap[Int, ExpressionEncoder[_]]() - override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})" /** @@ -111,7 +109,7 @@ case class ScalaUDF( } else { val encoder = inputEncoders(i) if (encoder.isDefined && encoder.get.isSerializedAsStructForTopLevel) { - val enc = resolvedEnc.getOrElseUpdate(i, encoder.get.resolveAndBind()) + val enc = encoder.get.resolveAndBind() row: Any => enc.fromRow(row.asInstanceOf[InternalRow]) } else { CatalystTypeConverters.createToScalaConverter(dataType) From 8e82f3f75a770fc9c6163a483f297eac38c30edd 
Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 23 Mar 2020 09:29:51 +0800 Subject: [PATCH 10/11] remove applyOption --- .../sql/catalyst/encoders/ExpressionEncoder.scala | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 43d5acf6c455e..b820cb1a5c522 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -60,19 +60,6 @@ object ExpressionEncoder { ClassTag[T](cls)) } - /** - * Unlike apply(), this method return None instead of throwing exception - * when there's no encoder found for the type `T`. This's mainly used for - * typed Scala UDF to workaround 'Any' type. - */ - def applyOption[T : TypeTag](): Option[ExpressionEncoder[T]] = { - try { - Option(ExpressionEncoder[T]()) - } catch { - case _: Exception => None - } - } - // TODO: improve error message for java bean encoder. def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = { val schema = JavaTypeInference.inferDataType(beanClass)._1 From b0b298e2d42785c54b1ffb10125741bce7e217e8 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 23 Mar 2020 17:33:39 +0800 Subject: [PATCH 11/11] update comment --- .../org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 3259f53072f01..1ac7ca676a876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -59,8 +59,9 @@ case class ScalaUDF( /** * The analyzer should be aware of Scala primitive types so as to make the * UDF return null if there is any null input value of these types. On the - * other hand, Java UDFs can only have boxed types, thus this parameter will - * always be all false. + * other hand, Java UDFs can only have boxed types, thus this will return + * Nil (which has the same effect as all false) and the analyzer will skip + * null-handling for them. */ def inputPrimitives: Seq[Boolean] = { inputEncoders.map { encoderOpt =>
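
To illustrate the end state of this series: with inputPrimitives and inputTypes now derived from inputEncoders, the ScalaUDF node carries a single source of truth instead of three parallel sequences that every caller had to keep consistent. Below is a minimal, self-contained sketch (not part of these patches) of the two behaviors the comments above describe. The object, session, and column names are hypothetical, and the expected outputs assume the analyzer's null-handling works as documented.

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.functions.{col, struct, udf}

  object InputEncodersSketch {
    // Hypothetical case class standing in for any Product type; its encoder
    // is serialized as a struct, so isSerializedAsStruct is true.
    case class Point(x: Int, y: Int)

    def main(args: Array[String]): Unit = {
      val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
      import spark.implicits._

      // Typed Scala UDF over a primitive Int: the Int encoder's schema is
      // non-nullable, so inputPrimitives is Seq(true) and the analyzer adds
      // an If(IsNull(...)) guard, meaning a null input yields null instead
      // of failing on unboxing.
      val plusOne = udf((i: Int) => i + 1)
      Seq[java.lang.Integer](1, null).toDF("x")
        .select(plusOne(col("x")))
        .show() // expected rows: 2, then null

      // UDF over a case class: the input is a struct, so inputPrimitives is
      // Seq(false) (no null guard) and the incoming row is converted back to
      // a Point through the encoder's resolveAndBind()/fromRow path.
      val dist = udf((p: Point) => math.hypot(p.x, p.y))
      Seq((3, 4)).toDF("x", "y")
        .select(dist(struct(col("x"), col("y"))))
        .show() // expected row: 5.0

      spark.stop()
    }
  }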