diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index a57673334c10b..6accf1f75064c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -70,15 +70,31 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @param name the name of the UDAF. * @param udaf the UDAF needs to be registered. * @return the registered UDAF. + * + * @since 1.5.0 */ - def register( - name: String, - udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { + def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf) functionRegistry.registerFunction(name, builder) udaf } + /** + * Register a user-defined function (UDF), for a UDF that's already defined using the DataFrame + * API (i.e. of type UserDefinedFunction). + * + * @param name the name of the UDF. + * @param udf the UDF needs to be registered. + * @return the registered UDF. 
+ * + * @since 2.2.0 + */ + def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = { + def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr + functionRegistry.registerFunction(name, builder) + udf + } + // scalastyle:off line.size.limit /* register 0-22 were generated by this script diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index ae6b2bc3753fb..6f8723af91cea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -93,6 +93,13 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } + test("UDF defined using UserDefinedFunction") { + import functions.udf + val foo = udf((x: Int) => x + 1) + spark.udf.register("foo", foo) + assert(sql("SELECT foo(5)").head().getInt(0) === 6) + } + test("ZeroArgument UDF") { spark.udf.register("random0", () => { Math.random()}) assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)